summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--BUILD8
-rw-r--r--Makefile82
-rw-r--r--WORKSPACE397
-rw-r--r--g3doc/user_guide/containerd/quick_start.md4
-rw-r--r--g3doc/user_guide/install.md19
-rw-r--r--go.mod14
-rw-r--r--go.sum123
-rw-r--r--images/README.md6
-rw-r--r--images/defs.bzl17
-rw-r--r--nogo.yaml253
-rw-r--r--pkg/abi/linux/BUILD2
-rw-r--r--pkg/abi/linux/ioctl.go3
-rw-r--r--pkg/abi/linux/sem.go12
-rw-r--r--pkg/abi/linux/sem_amd64.go33
-rw-r--r--pkg/abi/linux/sem_arm64.go31
-rw-r--r--pkg/bpf/decoder.go2
-rw-r--r--pkg/context/context.go24
-rw-r--r--pkg/merkletree/BUILD10
-rw-r--r--pkg/merkletree/merkletree.go118
-rw-r--r--pkg/merkletree/merkletree_test.go407
-rw-r--r--pkg/refs/refcounter.go10
-rw-r--r--pkg/refsvfs2/BUILD (renamed from pkg/refs_vfs2/BUILD)19
-rw-r--r--pkg/refsvfs2/refs.go (renamed from pkg/refs_vfs2/refs.go)4
-rw-r--r--pkg/refsvfs2/refs_map.go131
-rw-r--r--pkg/refsvfs2/refs_template.go (renamed from pkg/refs_vfs2/refs_template.go)76
-rw-r--r--pkg/sentry/control/state.go2
-rw-r--r--pkg/sentry/devices/tundev/tundev.go4
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go7
-rw-r--r--pkg/sentry/fs/gofer/path.go3
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go83
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go6
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go2
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD3
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go17
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go2
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/BUILD5
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/save_restore.go23
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD3
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go20
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go6
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD3
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go8
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go52
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go482
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go22
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go329
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go108
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go12
-rw-r--r--pkg/sentry/fsimpl/host/BUILD7
-rw-r--r--pkg/sentry/fsimpl/host/control.go2
-rw-r--r--pkg/sentry/fsimpl/host/host.go255
-rw-r--r--pkg/sentry/fsimpl/host/save_restore.go70
-rw-r--r--pkg/sentry/fsimpl/host/util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD40
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go14
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go93
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go112
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go264
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go62
-rw-r--r--pkg/sentry/fsimpl/kernfs/mmap_util.go (renamed from pkg/sentry/fsimpl/host/mmap.go)83
-rw-r--r--pkg/sentry/fsimpl/kernfs/save_restore.go36
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go8
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go11
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD3
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go2
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go146
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go242
-rw-r--r--pkg/sentry/fsimpl/overlay/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go2
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD11
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go36
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task.go7
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go16
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go173
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go34
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go30
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go8
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go116
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go1
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go4
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD3
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go2
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go70
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/save_restore.go20
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go11
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD6
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go138
-rw-r--r--pkg/sentry/fsimpl/verity/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go246
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go834
-rw-r--r--pkg/sentry/hostfd/BUILD2
-rw-r--r--pkg/sentry/hostfd/hostfd_linux.go18
-rw-r--r--pkg/sentry/hostfd/hostfd_unsafe.go9
-rw-r--r--pkg/sentry/inet/inet.go6
-rw-r--r--pkg/sentry/inet/test_stack.go21
-rw-r--r--pkg/sentry/kernel/BUILD12
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go10
-rw-r--r--pkg/sentry/kernel/fd_table.go6
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go13
-rw-r--r--pkg/sentry/kernel/fs_context.go24
-rw-r--r--pkg/sentry/kernel/ipc_namespace.go2
-rw-r--r--pkg/sentry/kernel/kernel.go176
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go8
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go50
-rw-r--r--pkg/sentry/kernel/pipe/pipe_test.go11
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go8
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go75
-rw-r--r--pkg/sentry/kernel/sessions.go32
-rw-r--r--pkg/sentry/kernel/shm/BUILD4
-rw-r--r--pkg/sentry/kernel/task_clone.go2
-rw-r--r--pkg/sentry/kernel/task_usermem.go95
-rw-r--r--pkg/sentry/kernel/vdso.go2
-rw-r--r--pkg/sentry/mm/BUILD5
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go17
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go17
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go7
-rw-r--r--pkg/sentry/platform/kvm/kvm.go3
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go5
-rw-r--r--pkg/sentry/platform/kvm/machine.go48
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go63
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go33
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go67
-rw-r--r--pkg/sentry/platform/ring0/defs.go3
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go10
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go7
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s67
-rw-r--r--pkg/sentry/platform/ring0/kernel.go6
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go5
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go4
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go45
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go84
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go21
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go15
-rw-r--r--pkg/sentry/platform/ring0/pagetables/walker_arm64.go2
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go10
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/hostinet/stack.go21
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go63
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go4
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go4
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go52
-rw-r--r--pkg/sentry/socket/netfilter/targets.go217
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go12
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go4
-rw-r--r--pkg/sentry/socket/netstack/stack.go85
-rw-r--r--pkg/sentry/socket/unix/BUILD5
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go16
-rw-r--r--pkg/sentry/socket/unix/unix.go1
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/state/BUILD2
-rw-r--r--pkg/sentry/state/state.go10
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go52
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go30
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go23
-rw-r--r--pkg/sentry/vfs/BUILD8
-rw-r--r--pkg/sentry/vfs/epoll.go2
-rw-r--r--pkg/sentry/vfs/file_description.go1
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go11
-rw-r--r--pkg/sentry/vfs/inotify.go2
-rw-r--r--pkg/sentry/vfs/lock.go5
-rw-r--r--pkg/sentry/vfs/mount.go69
-rw-r--r--pkg/sentry/vfs/mount_test.go35
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go21
-rw-r--r--pkg/sentry/vfs/save_restore.go124
-rw-r--r--pkg/sentry/vfs/vfs.go24
-rw-r--r--pkg/shim/runsc/BUILD1
-rw-r--r--pkg/shim/runsc/runsc.go12
-rw-r--r--pkg/state/BUILD14
-rw-r--r--pkg/state/decode.go80
-rw-r--r--pkg/state/decode_unsafe.go57
-rw-r--r--pkg/state/encode.go221
-rw-r--r--pkg/state/pretty/pretty.go41
-rw-r--r--pkg/state/state.go10
-rw-r--r--pkg/state/tests/struct.go35
-rw-r--r--pkg/state/tests/struct_test.go34
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go11
-rw-r--r--pkg/tcpip/checker/checker.go49
-rw-r--r--pkg/tcpip/header/icmpv4.go9
-rw-r--r--pkg/tcpip/header/ipv4.go493
-rw-r--r--pkg/tcpip/header/ipv6.go12
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go36
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go7
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go89
-rw-r--r--pkg/tcpip/network/arp/arp_test.go201
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go58
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go120
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go22
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go23
-rw-r--r--pkg/tcpip/network/ip_test.go115
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go168
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go542
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go848
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go135
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go139
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go275
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go458
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go18
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go12
-rw-r--r--pkg/tcpip/stack/conntrack.go14
-rw-r--r--pkg/tcpip/stack/forwarding_test.go6
-rw-r--r--pkg/tcpip/stack/iptables.go75
-rw-r--r--pkg/tcpip/stack/iptables_targets.go102
-rw-r--r--pkg/tcpip/stack/iptables_types.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go8
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go33
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go531
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go221
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2487
-rw-r--r--pkg/tcpip/stack/nic.go63
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/nud.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer.go60
-rw-r--r--pkg/tcpip/stack/pending_packets.go2
-rw-r--r--pkg/tcpip/stack/registration.go51
-rw-r--r--pkg/tcpip/stack/route.go288
-rw-r--r--pkg/tcpip/stack/stack.go324
-rw-r--r--pkg/tcpip/stack/stack_test.go568
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go63
-rw-r--r--pkg/tcpip/stack/transport_test.go35
-rw-r--r--pkg/tcpip/tcpip.go24
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go46
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go48
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go2
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go24
-rw-r--r--pkg/tcpip/tests/integration/route_test.go388
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go16
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go38
-rw-r--r--pkg/tcpip/transport/tcp/accept.go93
-rw-r--r--pkg/tcpip/transport/tcp/connect.go30
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go7
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go39
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go12
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go31
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment.go47
-rw-r--r--pkg/tcpip/transport/tcp/snd.go4
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go20
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go98
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go50
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go24
-rw-r--r--pkg/tcpip/transport/udp/protocol.go8
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go5
-rw-r--r--pkg/unet/unet_test.go157
-rw-r--r--pkg/waiter/waiter.go2
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/controller.go50
-rw-r--r--runsc/boot/fs.go33
-rw-r--r--runsc/boot/loader.go11
-rw-r--r--runsc/boot/loader_test.go57
-rw-r--r--runsc/boot/vfs.go28
-rw-r--r--runsc/cgroup/cgroup.go11
-rw-r--r--runsc/cgroup/cgroup_test.go80
-rw-r--r--runsc/cmd/boot.go6
-rw-r--r--runsc/cmd/checkpoint.go5
-rw-r--r--runsc/cmd/debug.go4
-rw-r--r--runsc/cmd/delete.go2
-rw-r--r--runsc/cmd/events.go11
-rw-r--r--runsc/cmd/exec.go2
-rw-r--r--runsc/cmd/kill.go2
-rw-r--r--runsc/cmd/list.go2
-rw-r--r--runsc/cmd/pause.go2
-rw-r--r--runsc/cmd/ps.go2
-rw-r--r--runsc/cmd/resume.go2
-rw-r--r--runsc/cmd/start.go9
-rw-r--r--runsc/cmd/state.go2
-rw-r--r--runsc/cmd/wait.go2
-rw-r--r--runsc/config/config.go3
-rw-r--r--runsc/config/flags.go3
-rw-r--r--runsc/container/container.go179
-rw-r--r--runsc/container/container_test.go18
-rw-r--r--runsc/container/multi_container_test.go76
-rw-r--r--runsc/sandbox/sandbox.go13
-rw-r--r--runsc/specutils/specutils.go50
-rw-r--r--test/benchmarks/base/BUILD2
-rw-r--r--test/iptables/filter_output.go17
-rw-r--r--test/iptables/nat.go15
-rw-r--r--test/packetimpact/netdevs/netdevs.go2
-rw-r--r--test/packetimpact/runner/defs.bzl3
-rw-r--r--test/packetimpact/testbench/connections.go55
-rw-r--r--test/packetimpact/testbench/dut_client.go2
-rw-r--r--test/packetimpact/testbench/layers.go40
-rw-r--r--test/packetimpact/tests/BUILD13
-rw-r--r--test/packetimpact/tests/ipv4_fragment_reassembly_test.go142
-rw-r--r--test/packetimpact/tests/ipv6_fragment_reassembly_test.go237
-rw-r--r--test/packetimpact/tests/tcp_network_unreachable_test.go4
-rw-r--r--test/root/oom_score_adj_test.go12
-rw-r--r--test/runner/defs.bzl2
-rw-r--r--test/runner/runner.go6
-rw-r--r--test/runtimes/exclude/java11.csv1
-rw-r--r--test/runtimes/exclude/nodejs12.4.0.csv41
-rw-r--r--test/runtimes/exclude/php7.3.6.csv6
-rw-r--r--test/runtimes/exclude/python3.7.3.csv2
-rw-r--r--test/runtimes/proctor/main.go28
-rw-r--r--test/runtimes/runner/lib/lib.go39
-rw-r--r--test/runtimes/runner/main.go14
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/BUILD30
-rw-r--r--test/syscalls/linux/mknod.cc30
-rw-r--r--test/syscalls/linux/mmap.cc6
-rw-r--r--test/syscalls/linux/mount.cc38
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc11
-rw-r--r--test/syscalls/linux/proc.cc409
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc2
-rw-r--r--test/syscalls/linux/ptrace.cc43
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc13
-rw-r--r--test/syscalls/linux/semaphore.cc287
-rw-r--r--test/syscalls/linux/sendfile.cc53
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc26
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc14
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc18
-rw-r--r--test/syscalls/linux/socket_ip_unbound_netlink.cc104
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc82
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc28
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc63
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc22
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h11
-rw-r--r--test/syscalls/linux/socket_test_util.cc16
-rw-r--r--test/syscalls/linux/socket_test_util.h9
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc13
-rw-r--r--test/syscalls/linux/splice.cc55
-rw-r--r--test/syscalls/linux/stat.cc34
-rw-r--r--test/syscalls/linux/tcp_socket.cc57
-rw-r--r--test/syscalls/linux/timers.cc5
-rw-r--r--test/syscalls/linux/udp_socket.cc72
-rw-r--r--test/util/BUILD4
-rw-r--r--test/util/posix_error.cc2
-rw-r--r--test/util/posix_error.h34
-rw-r--r--test/util/save_util.cc52
-rw-r--r--test/util/save_util.h20
-rw-r--r--test/util/save_util_linux.cc22
-rw-r--r--test/util/save_util_other.cc8
-rw-r--r--test/util/signal_util.h2
-rw-r--r--test/util/timer_util.h5
-rw-r--r--tools/bazel.mk40
-rw-r--r--tools/bazeldefs/go.bzl8
-rw-r--r--tools/bigquery/BUILD5
-rw-r--r--tools/bigquery/bigquery.go63
-rw-r--r--tools/defs.bzl12
-rw-r--r--tools/github/nogo/BUILD2
-rw-r--r--tools/github/nogo/nogo.go19
-rw-r--r--tools/github/reviver/github.go19
-rw-r--r--tools/github/reviver/reviver_test.go9
-rwxr-xr-xtools/go_branch.sh45
-rw-r--r--tools/go_generics/imports.go4
-rw-r--r--tools/go_marshal/test/escape/escape.go1
-rw-r--r--tools/go_marshal/test/test.go2
-rw-r--r--tools/nogo/BUILD10
-rw-r--r--tools/nogo/analyzers.go131
-rw-r--r--tools/nogo/build.go16
-rw-r--r--tools/nogo/check/BUILD2
-rw-r--r--tools/nogo/check/main.go92
-rw-r--r--tools/nogo/config.go757
-rw-r--r--tools/nogo/defs.bzl195
-rw-r--r--tools/nogo/filter/BUILD14
-rw-r--r--tools/nogo/filter/main.go131
-rw-r--r--tools/nogo/findings.go63
-rwxr-xr-xtools/nogo/gentest.sh48
-rw-r--r--tools/nogo/matchers.go172
-rw-r--r--tools/nogo/nogo.go175
-rw-r--r--tools/nogo/register.go67
-rw-r--r--tools/nogo/util/BUILD9
-rw-r--r--tools/nogo/util/util.go85
-rw-r--r--tools/parsers/BUILD20
-rw-r--r--tools/parsers/go_parser.go19
-rw-r--r--tools/parsers/go_parser_test.go12
-rw-r--r--tools/parsers/parser_main.go135
-rw-r--r--tools/parsers/version.go18
-rwxr-xr-xtools/tag_release.sh6
-rw-r--r--webhook/BUILD28
-rw-r--r--webhook/main.go24
-rw-r--r--webhook/pkg/cli/BUILD17
-rw-r--r--webhook/pkg/cli/cli.go115
-rw-r--r--webhook/pkg/injector/BUILD34
-rwxr-xr-xwebhook/pkg/injector/gencerts.sh71
-rw-r--r--webhook/pkg/injector/webhook.go211
-rw-r--r--website/BUILD7
-rw-r--r--website/_config.yml7
-rw-r--r--website/_includes/byline.html2
-rw-r--r--website/blog/2020-10-22-platform-portability.md120
-rw-r--r--website/blog/BUILD11
-rw-r--r--website/cmd/server/main.go20
416 files changed, 18142 insertions, 7073 deletions
diff --git a/BUILD b/BUILD
index 63dd05118..a133f16e9 100644
--- a/BUILD
+++ b/BUILD
@@ -1,10 +1,17 @@
load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
+load("//tools/nogo:defs.bzl", "nogo_config")
load("//website:defs.bzl", "doc")
package(licenses = ["notice"])
exports_files(["LICENSE"])
+nogo_config(
+ name = "nogo_config",
+ srcs = ["nogo.yaml"],
+ visibility = ["//:sandbox"],
+)
+
doc(
name = "contributing",
src = "CONTRIBUTING.md",
@@ -86,6 +93,7 @@ go_path(
"//runsc/cli",
"//shim/v1/cli",
"//shim/v2/cli",
+ "//webhook/pkg/cli",
# Packages that are not dependencies of the above.
"//pkg/sentry/kernel/memevent",
diff --git a/Makefile b/Makefile
index afc25557e..5a95adde5 100644
--- a/Makefile
+++ b/Makefile
@@ -156,12 +156,24 @@ syscall-tests: ## Run all system call tests.
@$(call submake,test TARGETS="test/syscalls/...")
%-runtime-tests: load-runtimes_%
+ifeq ($(PARTITION),)
+ @$(eval PARTITION := 1)
+endif
+ifeq ($(TOTAL_PARTITIONS),)
+ @$(eval TOTAL_PARTITIONS := 1)
+endif
@$(call submake,install-test-runtime)
- @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+ @$(call submake,test-runtime OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
%-runtime-tests_vfs2: load-runtimes_%
+ifeq ($(PARTITION),)
+ @$(eval PARTITION := 1)
+endif
+ifeq ($(TOTAL_PARTITIONS),)
+ @$(eval TOTAL_PARTITIONS := 1)
+endif
@$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+ @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")
do-tests: runsc
@$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
@@ -210,6 +222,15 @@ iptables-tests: load-iptables
@$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
.PHONY: iptables-tests
+# Run the iptables tests with runsc only. Useful for developing to skip runc
+# testing.
+iptables-runsc-tests: load-iptables
+ @sudo modprobe iptable_filter
+ @sudo modprobe ip6table_filter
+ @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw")
+ @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
+.PHONY: iptables-runsc-tests
+
packetdrill-tests: load-packetdrill
@$(call submake,install-test-runtime RUNTIME="packetdrill")
@$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')")
@@ -240,6 +261,61 @@ containerd-tests: containerd-test-1.3.4
containerd-tests: containerd-test-1.4.0-beta.0
##
+## Benchmarks.
+##
+## Targets to run benchmarks. See //test/benchmarks for details.
+##
+## common arguments:
+## RUNTIME_ARGS - arguments to runsc placed in /etc/docker/daemon.json
+## e.g. "--platform=ptrace"
+## BENCHMARKS_PROJECT - BigQuery project to which to send data.
+## BENCHMARKS_DATASET - BigQuery dataset to which to send data.
+## BENCHMARKS_TABLE - BigQuery table to which to send data.
+## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go.
+## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run.
+## BENCHMARKS_OFFICIAL - marks the data as official.
+## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm).
+##
+RUNTIME_ARGS := --net-raw --platform=ptrace
+BENCHMARKS_PROJECT := gvisor-benchmarks
+BENCHMARKS_DATASET := kokoro
+BENCHMARKS_TABLE := benchmarks
+BENCHMARKS_SUITE := start
+BENCHMARKS_UPLOAD := false
+BENCHMARKS_OFFICIAL := false
+BENCHMARKS_PLATFORMS := ptrace
+
+init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema
+## (see //tools/bigquery/bigquery.go). If the table alread exists, this is a noop.
+ $(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \
+ --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)")
+.PHONY: init-benchmark-table
+
+benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
+ $(call submake, run-benchmark RUNTIME="runc")
+ $(foreach PLATFORM,$(BENCHMARKS_PLATFORMS),\
+ $(call submake,benchmark-platform RUNTIME="$(PLATFORM)" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw --vfs2") && \
+ $(call submake,benchmark-platform RUNTIME="$(PLATFORM)_vfs1" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw"))
+.PHONY: benchmark-platforms
+
+benchmark-platform: ## Installs a runtime with the given platform args.
+ @$(call submake,install-test-runtime ARGS="$(RUNTIME_ARGS)")
+ @$(call submake, run-benchmark)
+.PHONY: benchmark-platform
+
+run-benchmark: ## Runs single benchmark and optionally sends data to BigQuery.
+ $(eval T := $(shell mktemp /tmp/logs.$(RUNTIME).XXXXXX))
+ $(call submake,sudo TARGETS="$(TARGETS)" ARGS="--runtime=$(RUNTIME) $(ARGS)" | tee $(T))
+ @if [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \
+ @$(call submake,run TARGETS=tools/parsers:parser ARGS="parse --file=$(T) \
+ --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \
+ --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \
+ --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \
+ fi;
+ rm -rf $T
+.PHONY: run-benchmark
+
+##
## Website & documentation helpers.
##
## The website is built from repository documentation and wrappers, using
@@ -260,7 +336,7 @@ website-build: load-jekyll ## Build the site image locally.
.PHONY: website-build
website-server: website-build ## Run a local server for development.
- @docker run -i -p 8080:8080 gvisor.dev/images/website
+ @docker run -i -p 8080:8080 $(WEBSITE_IMAGE)
.PHONY: website-server
website-push: website-build ## Push a new image and update the service.
diff --git a/WORKSPACE b/WORKSPACE
index 30d21e472..2f3408709 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -23,13 +23,13 @@ bazel_skylib_workspace()
http_archive(
name = "io_bazel_rules_go",
- sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00",
patch_args = ["-p1"],
patches = [
# Newer versions of the rules_go rules will automatically strip test
# binaries of symbols, which we don't want.
"//tools:rules_go.patch",
],
+ sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
"https://github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
@@ -49,7 +49,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe
go_rules_dependencies()
-go_register_toolchains(go_version = "1.14.2")
+go_register_toolchains(go_version = "1.15.2")
load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
@@ -58,7 +58,7 @@ gazelle_dependencies()
# The com_google_protobuf repository below would trigger downloading a older
# version of org_golang_x_sys. If putting this repository statment in a place
# after that of the com_google_protobuf, this statement will not work as
-# expectd to download a new version of org_golang_x_sys.
+# expected to download a new version of org_golang_x_sys.
go_repository(
name = "org_golang_x_sys",
importpath = "golang.org/x/sys",
@@ -222,8 +222,8 @@ go_repository(
go_repository(
name = "com_github_google_uuid",
importpath = "github.com/google/uuid",
- sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=",
- version = "v1.0.0",
+ sum = "h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=",
+ version = "v1.1.1",
)
go_repository(
@@ -328,8 +328,8 @@ go_repository(
go_repository(
name = "org_golang_x_tools",
importpath = "golang.org/x/tools",
- sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=",
- version = "v0.0.0-20201002184944-ecd9fd270d5d",
+ sum = "h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk=",
+ version = "v0.0.0-20201021000207-d49c4edd7d96",
)
go_repository(
@@ -349,8 +349,8 @@ go_repository(
go_repository(
name = "com_github_golang_protobuf",
importpath = "github.com/golang/protobuf",
- sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=",
- version = "v1.4.2",
+ sum = "h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0=",
+ version = "v1.4.1",
)
go_repository(
@@ -412,7 +412,7 @@ go_repository(
go_repository(
name = "com_github_konsorten_go_windows_terminal_sequences",
importpath = "github.com/konsorten/go-windows-terminal-sequences",
- sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=",
+ sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=",
version = "v1.0.3",
)
@@ -461,8 +461,8 @@ go_repository(
go_repository(
name = "org_uber_go_multierr",
importpath = "go.uber.org/multierr",
- sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=",
- version = "v1.2.0",
+ sum = "h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=",
+ version = "v1.6.0",
)
go_repository(
@@ -482,8 +482,8 @@ go_repository(
go_repository(
name = "co_honnef_go_tools",
importpath = "honnef.co/go/tools",
- sum = "h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=",
- version = "v0.0.1-2019.2.3",
+ sum = "h1:W18jzjh8mfPez+AwGLxmOImucz/IFjpNlrKVnaj2YVc=",
+ version = "v0.0.1-2020.1.6",
)
go_repository(
@@ -623,8 +623,8 @@ go_repository(
go_repository(
name = "com_github_google_go_cmp",
importpath = "github.com/google/go-cmp",
- sum = "h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=",
- version = "v0.5.1",
+ sum = "h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI=",
+ version = "v0.5.3-0.20201020212313-ab46b8bd0abd",
)
go_repository(
@@ -721,8 +721,8 @@ go_repository(
go_repository(
name = "com_github_spf13_pflag",
importpath = "github.com/spf13/pflag",
- sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=",
- version = "v1.0.1-0.20171106142849-4c012f6dcd95",
+ sum = "h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=",
+ version = "v1.0.5",
)
go_repository(
@@ -763,15 +763,15 @@ go_repository(
go_repository(
name = "org_golang_google_genproto",
importpath = "google.golang.org/genproto",
- sum = "h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=",
- version = "v0.0.0-20200526211855-cb27e3aa2013",
+ sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=",
+ version = "v0.0.0-20200117163144-32f20d992d24",
)
go_repository(
name = "org_golang_google_protobuf",
importpath = "google.golang.org/protobuf",
- sum = "h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=",
- version = "v1.25.1-0.20200808011614-a180de9f97d9",
+ sum = "h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A=",
+ version = "v1.25.1-0.20201020201750-d3470999428b",
)
go_repository(
@@ -1032,3 +1032,356 @@ go_repository(
sum = "h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=",
version = "v1.0.0",
)
+
+go_repository(
+ name = "com_github_azure_go_autorest_autorest",
+ importpath = "github.com/Azure/go-autorest/autorest",
+ sum = "h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs=",
+ version = "v0.9.0",
+)
+
+go_repository(
+ name = "com_github_azure_go_autorest_autorest_adal",
+ importpath = "github.com/Azure/go-autorest/autorest/adal",
+ sum = "h1:q2gDruN08/guU9vAjuPWff0+QIrpH6ediguzdAzXAUU=",
+ version = "v0.5.0",
+)
+
+go_repository(
+ name = "com_github_azure_go_autorest_autorest_date",
+ importpath = "github.com/Azure/go-autorest/autorest/date",
+ sum = "h1:YGrhWfrgtFs84+h0o46rJrlmsZtyZRg470CqAXTZaGM=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_azure_go_autorest_autorest_mocks",
+ importpath = "github.com/Azure/go-autorest/autorest/mocks",
+ sum = "h1:Ww5g4zThfD/6cLb4z6xxgeyDa7QDkizMkJKe0ysZXp0=",
+ version = "v0.2.0",
+)
+
+go_repository(
+ name = "com_github_azure_go_autorest_logger",
+ importpath = "github.com/Azure/go-autorest/logger",
+ sum = "h1:ruG4BSDXONFRrZZJ2GUXDiUyVpayPmb1GnWeHDdaNKY=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_azure_go_autorest_tracing",
+ importpath = "github.com/Azure/go-autorest/tracing",
+ sum = "h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k=",
+ version = "v0.5.0",
+)
+
+go_repository(
+ name = "com_github_dgrijalva_jwt_go",
+ importpath = "github.com/dgrijalva/jwt-go",
+ sum = "h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=",
+ version = "v3.2.0+incompatible",
+)
+
+go_repository(
+ name = "com_github_docker_spdystream",
+ importpath = "github.com/docker/spdystream",
+ sum = "h1:cenwrSVm+Z7QLSV/BsnenAOcDXdX4cMv4wP0B/5QbPg=",
+ version = "v0.0.0-20160310174837-449fdfce4d96",
+)
+
+go_repository(
+ name = "com_github_elazarl_goproxy",
+ importpath = "github.com/elazarl/goproxy",
+ sum = "h1:p1yVGRW3nmb85p1Sh1ZJSDm4A4iKLS5QNbvUHMgGu/M=",
+ version = "v0.0.0-20170405201442-c4fc26588b6e",
+)
+
+go_repository(
+ name = "com_github_emicklei_go_restful",
+ importpath = "github.com/emicklei/go-restful",
+ sum = "h1:H2pdYOb3KQ1/YsqVWoWNLQO+fusocsw354rqGTZtAgw=",
+ version = "v0.0.0-20170410110728-ff4f55a20633",
+)
+
+go_repository(
+ name = "com_github_evanphx_json_patch",
+ importpath = "github.com/evanphx/json-patch",
+ sum = "h1:fUDGZCv/7iAN7u0puUVhvKCcsR6vRfwrJatElLBEf0I=",
+ version = "v4.2.0+incompatible",
+)
+
+go_repository(
+ name = "com_github_fsnotify_fsnotify",
+ importpath = "github.com/fsnotify/fsnotify",
+ sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=",
+ version = "v1.4.7",
+)
+
+go_repository(
+ name = "com_github_ghodss_yaml",
+ importpath = "github.com/ghodss/yaml",
+ sum = "h1:ZktWZesgun21uEDrwW7iEV1zPCGQldM2atlJZ3TdvVM=",
+ version = "v0.0.0-20150909031657-73d445a93680",
+)
+
+go_repository(
+ name = "com_github_go_logr_logr",
+ importpath = "github.com/go-logr/logr",
+ sum = "h1:M1Tv3VzNlEHg6uyACnRdtrploV2P7wZqH8BoQMtz0cg=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_go_openapi_jsonpointer",
+ importpath = "github.com/go-openapi/jsonpointer",
+ sum = "h1:wSt/4CYxs70xbATrGXhokKF1i0tZjENLOo1ioIO13zk=",
+ version = "v0.0.0-20160704185906-46af16f9f7b1",
+)
+
+go_repository(
+ name = "com_github_go_openapi_jsonreference",
+ importpath = "github.com/go-openapi/jsonreference",
+ sum = "h1:tF+augKRWlWx0J0B7ZyyKSiTyV6E1zZe+7b3qQlcEf8=",
+ version = "v0.0.0-20160704190145-13c6e3589ad9",
+)
+
+go_repository(
+ name = "com_github_go_openapi_spec",
+ importpath = "github.com/go-openapi/spec",
+ sum = "h1:C1JKChikHGpXwT5UQDFaryIpDtyyGL/CR6C2kB7F1oc=",
+ version = "v0.0.0-20160808142527-6aced65f8501",
+)
+
+go_repository(
+ name = "com_github_go_openapi_swag",
+ importpath = "github.com/go-openapi/swag",
+ sum = "h1:zP3nY8Tk2E6RTkqGYrarZXuzh+ffyLDljLxCy1iJw80=",
+ version = "v0.0.0-20160704191624-1d0bd113de87",
+)
+
+go_repository(
+ name = "com_github_google_gofuzz",
+ importpath = "github.com/google/gofuzz",
+ sum = "h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_googleapis_gnostic",
+ build_file_proto_mode = "disable_global",
+ importpath = "github.com/googleapis/gnostic",
+ sum = "h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k=",
+ version = "v0.0.0-20170729233727-0c5108395e2d",
+)
+
+go_repository(
+ name = "com_github_gophercloud_gophercloud",
+ importpath = "github.com/gophercloud/gophercloud",
+ sum = "h1:P/nh25+rzXouhytV2pUHBb65fnds26Ghl8/391+sT5o=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_gregjones_httpcache",
+ importpath = "github.com/gregjones/httpcache",
+ sum = "h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=",
+ version = "v0.0.0-20180305231024-9cad4c3443a7",
+)
+
+go_repository(
+ name = "com_github_hpcloud_tail",
+ importpath = "github.com/hpcloud/tail",
+ sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_imdario_mergo",
+ importpath = "github.com/imdario/mergo",
+ sum = "h1:JboBksRwiiAJWvIYJVo46AfV+IAIKZpfrSzVKj42R4Q=",
+ version = "v0.3.5",
+)
+
+go_repository(
+ name = "com_github_json_iterator_go",
+ importpath = "github.com/json-iterator/go",
+ sum = "h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=",
+ version = "v1.1.7",
+)
+
+go_repository(
+ name = "com_github_mailru_easyjson",
+ importpath = "github.com/mailru/easyjson",
+ sum = "h1:TpvdAwDAt1K4ANVOfcihouRdvP+MgAfDWwBuct4l6ZY=",
+ version = "v0.0.0-20160728113105-d5b7844b561a",
+)
+
+go_repository(
+ name = "com_github_mattbaird_jsonpatch",
+ importpath = "github.com/mattbaird/jsonpatch",
+ sum = "h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8=",
+ version = "v0.0.0-20171005235357-81af80346b1a",
+)
+
+go_repository(
+ name = "com_github_modern_go_concurrent",
+ importpath = "github.com/modern-go/concurrent",
+ sum = "h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=",
+ version = "v0.0.0-20180306012644-bacd9c7ef1dd",
+)
+
+go_repository(
+ name = "com_github_modern_go_reflect2",
+ importpath = "github.com/modern-go/reflect2",
+ sum = "h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=",
+ version = "v1.0.1",
+)
+
+go_repository(
+ name = "com_github_munnerz_goautoneg",
+ importpath = "github.com/munnerz/goautoneg",
+ sum = "h1:7PxY7LVfSZm7PEeBTyK1rj1gABdCO2mbri6GKO1cMDs=",
+ version = "v0.0.0-20120707110453-a547fc61f48d",
+)
+
+go_repository(
+ name = "com_github_mxk_go_flowrate",
+ importpath = "github.com/mxk/go-flowrate",
+ sum = "h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=",
+ version = "v0.0.0-20140419014527-cca7078d478f",
+)
+
+go_repository(
+ name = "com_github_nytimes_gziphandler",
+ importpath = "github.com/NYTimes/gziphandler",
+ sum = "h1:lsxEuwrXEAokXB9qhlbKWPpo3KMLZQ5WB5WLQRW1uq0=",
+ version = "v0.0.0-20170623195520-56545f4a5d46",
+)
+
+go_repository(
+ name = "com_github_onsi_ginkgo",
+ importpath = "github.com/onsi/ginkgo",
+ sum = "h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w=",
+ version = "v1.8.0",
+)
+
+go_repository(
+ name = "com_github_onsi_gomega",
+ importpath = "github.com/onsi/gomega",
+ sum = "h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo=",
+ version = "v1.5.0",
+)
+
+go_repository(
+ name = "com_github_peterbourgon_diskv",
+ importpath = "github.com/peterbourgon/diskv",
+ sum = "h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=",
+ version = "v2.0.1+incompatible",
+)
+
+go_repository(
+ name = "com_github_puerkitobio_purell",
+ importpath = "github.com/PuerkitoBio/purell",
+ sum = "h1:0GoNN3taZV6QI81IXgCbxMyEaJDXMSIjArYBCYzVVvs=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_puerkitobio_urlesc",
+ importpath = "github.com/PuerkitoBio/urlesc",
+ sum = "h1:JCHLVE3B+kJde7bIEo5N4J+ZbLhp0J1Fs+ulyRws4gE=",
+ version = "v0.0.0-20160726150825-5bd2802263f2",
+)
+
+go_repository(
+ name = "com_github_spf13_afero",
+ importpath = "github.com/spf13/afero",
+ sum = "h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc=",
+ version = "v1.2.2",
+)
+
+go_repository(
+ name = "in_gopkg_fsnotify_v1",
+ importpath = "gopkg.in/fsnotify.v1",
+ sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=",
+ version = "v1.4.7",
+)
+
+go_repository(
+ name = "in_gopkg_inf_v0",
+ importpath = "gopkg.in/inf.v0",
+ sum = "h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=",
+ version = "v0.9.1",
+)
+
+go_repository(
+ name = "in_gopkg_tomb_v1",
+ importpath = "gopkg.in/tomb.v1",
+ sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=",
+ version = "v1.0.0-20141024135613-dd632973f1e7",
+)
+
+go_repository(
+ name = "io_k8s_api",
+ build_file_proto_mode = "disable_global",
+ importpath = "k8s.io/api",
+ sum = "h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs=",
+ version = "v0.16.13",
+)
+
+go_repository(
+ name = "io_k8s_apimachinery",
+ build_file_proto_mode = "disable_global",
+ importpath = "k8s.io/apimachinery",
+ sum = "h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc=",
+ version = "v0.16.14-rc.0",
+)
+
+go_repository(
+ name = "io_k8s_client_go",
+ importpath = "k8s.io/client-go",
+ sum = "h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag=",
+ version = "v0.16.13",
+)
+
+go_repository(
+ name = "io_k8s_gengo",
+ importpath = "k8s.io/gengo",
+ sum = "h1:4s3/R4+OYYYUKptXPhZKjQ04WJ6EhQQVFdjOFvCazDk=",
+ version = "v0.0.0-20190128074634-0689ccc1d7d6",
+)
+
+go_repository(
+ name = "io_k8s_klog",
+ importpath = "k8s.io/klog",
+ sum = "h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "io_k8s_kube_openapi",
+ importpath = "k8s.io/kube-openapi",
+ sum = "h1:PsbYeEz2x7ll6JYUzBEG+DT78910DDTlvn5Ma10F5/E=",
+ version = "v0.0.0-20200410163147-594e756bea31",
+)
+
+go_repository(
+ name = "io_k8s_sigs_structured_merge_diff",
+ importpath = "sigs.k8s.io/structured-merge-diff",
+ sum = "h1:4Z09Hglb792X0kfOBBJUPFEyvVfQWrYT/l8h5EKA6JQ=",
+ version = "v0.0.0-20190525122527-15d366b2352e",
+)
+
+go_repository(
+ name = "io_k8s_sigs_yaml",
+ importpath = "sigs.k8s.io/yaml",
+ sum = "h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs=",
+ version = "v1.1.0",
+)
+
+go_repository(
+ name = "io_k8s_utils",
+ importpath = "k8s.io/utils",
+ sum = "h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=",
+ version = "v0.0.0-20190801114015-581e00157fb1",
+)
diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md
index b6a3186d8..a98fe5c4a 100644
--- a/g3doc/user_guide/containerd/quick_start.md
+++ b/g3doc/user_guide/containerd/quick_start.md
@@ -1,7 +1,7 @@
# Containerd Quick Start
-This document describes how to install and configure `containerd-shim-runsc-v1`
-using the containerd runtime handler support on `containerd` 1.2 or later.
+This document describes how to use `containerd-shim-runsc-v1` with the
+containerd runtime handler support on `containerd` 1.2 or later.
> ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you
> may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details.
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
index abb9e8582..c3ced9d61 100644
--- a/g3doc/user_guide/install.md
+++ b/g3doc/user_guide/install.md
@@ -13,15 +13,19 @@ To download and install the latest release manually follow these steps:
(
set -e
URL=https://storage.googleapis.com/gvisor/releases/release/latest
- wget ${URL}/runsc ${URL}/runsc.sha512
- sha512sum -c runsc.sha512
- rm -f runsc.sha512
- sudo mv runsc /usr/local/bin
- sudo chmod a+rx /usr/local/bin/runsc
+ wget ${URL}/runsc ${URL}/runsc.sha512 \
+ ${URL}/gvisor-containerd-shim ${URL}/gvisor-containerd-shim.sha512 \
+ ${URL}/containerd-shim-runsc-v1 ${URL}/containerd-shim-runsc-v1.sha512
+ sha512sum -c runsc.sha512 \
+ -c gvisor-containerd-shim.sha512 \
+ -c containerd-shim-runsc-v1.sha512
+ rm -f *.sha512
+ chmod a+rx runsc gvisor-containerd-shim containerd-shim-runsc-v1
+ sudo mv runsc gvisor-containerd-shim containerd-shim-runsc-v1 /usr/local/bin
)
```
-To install gVisor with Docker, run the following commands:
+To install gVisor as a Docker runtime, run the following commands:
```bash
/usr/local/bin/runsc install
@@ -165,5 +169,6 @@ You can use this link with the steps described in
Note that `apt` installation of a specific point release is not supported.
After installation, try out `runsc` by following the
-[Docker Quick Start](./quick_start/docker.md) or
+[Docker Quick Start](./quick_start/docker.md),
+[Containerd QuickStart](./containerd/quick_start.md), or
[OCI Quick Start](./quick_start/oci.md).
diff --git a/go.mod b/go.mod
index e6df99177..144543169 100644
--- a/go.mod
+++ b/go.mod
@@ -29,11 +29,12 @@ require (
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect
github.com/gogo/googleapis v1.4.0 // indirect
- github.com/google/go-cmp v0.5.1 // indirect
+ github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd // indirect
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect
github.com/hashicorp/go-multierror v1.0.0 // indirect
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect
+ github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect
github.com/opencontainers/image-spec v1.0.1 // indirect
github.com/opencontainers/runc v0.1.1 // indirect
@@ -43,12 +44,13 @@ require (
github.com/urfave/cli v1.22.2 // indirect
github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
- go.uber.org/atomic v1.7.0 // indirect
- go.uber.org/multierr v1.2.0 // indirect
+ go.uber.org/multierr v1.6.0 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
- golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d // indirect
+ golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 // indirect
google.golang.org/grpc v1.29.0 // indirect
- google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 // indirect
- gopkg.in/yaml.v2 v2.2.8 // indirect
+ google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect
gotest.tools v2.2.0+incompatible // indirect
+ k8s.io/api v0.16.13
+ k8s.io/apimachinery v0.16.14-rc.0
+ k8s.io/client-go v0.16.13
)
diff --git a/go.sum b/go.sum
index e713d2eaa..060d5596a 100644
--- a/go.sum
+++ b/go.sum
@@ -13,6 +13,13 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
+github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
+github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0=
+github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA=
+github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
+github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
+github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc=
+github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
@@ -27,6 +34,9 @@ github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg3
github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8=
github.com/Microsoft/hcsshim v0.8.10 h1:k5wTrpnVU2/xv8ZuzGkbXVd3js5zJ8RnumPo5RxiIxU=
github.com/Microsoft/hcsshim v0.8.10/go.mod h1:g5uw8EV2mAlzqe94tfNBNdr89fnbD/n3HV0OhsddkmM=
+github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ=
+github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
+github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4=
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
@@ -68,8 +78,11 @@ github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQa
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
+github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE=
github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w=
@@ -80,14 +93,25 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA=
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
+github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c=
github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
+github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
+github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
+github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
+github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg=
+github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
+github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME=
@@ -96,9 +120,11 @@ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
+github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
+github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@@ -106,6 +132,7 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -125,11 +152,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=
-github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI=
+github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
+github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
+github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=
+github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
@@ -137,18 +167,28 @@ github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hf
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM=
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
-github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
+github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
+github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k=
+github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
+github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8=
+github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
+github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
+github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=
+github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
@@ -164,8 +204,25 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
+github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8=
+github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0=
+github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
+github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
+github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
+github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
+github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
+github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
+github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
+github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
@@ -181,9 +238,12 @@ github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNia
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
+github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
@@ -196,12 +256,18 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
+github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
+github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
+github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
+github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
@@ -217,12 +283,15 @@ go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
-go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=
-go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
+go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=
+go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
+golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -247,6 +316,7 @@ golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -275,8 +345,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,6 +358,7 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -295,6 +369,7 @@ golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
@@ -304,6 +379,7 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
@@ -322,8 +398,8 @@ golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=
-golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
+golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk=
+golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -349,9 +425,8 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=
google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
-google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -359,7 +434,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
-google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=
google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
@@ -368,13 +442,18 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=
-google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
+google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A=
+google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
+gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
+gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
@@ -383,4 +462,22 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
+k8s.io/api v0.16.13 h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs=
+k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs=
+k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
+k8s.io/apimachinery v0.16.14-rc.0 h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc=
+k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
+k8s.io/client-go v0.16.13 h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag=
+k8s.io/client-go v0.16.13/go.mod h1:UKvVT4cajC2iN7DCjLgT0KVY/cbY6DGdUCyRiIfws5M=
+k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
+k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
+k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
+k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8=
+k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
+k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
+k8s.io/utils v0.0.0-20190801114015-581e00157fb1 h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=
+k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
+sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI=
+sigs.k8s.io/yaml v1.1.0 h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs=
+sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=
diff --git a/images/README.md b/images/README.md
index 9880946a6..297c7c3f3 100644
--- a/images/README.md
+++ b/images/README.md
@@ -41,9 +41,9 @@ All images will be tagged and memoized using a hash of the directory contents.
As a result, every image should be made completely reproducible if possible.
This means using fixed tags and fixed versions whenever feasible.
-Notes that images should also be made architecture-independent if possible. The
-build scripts will handling loading the appropriate architecture onto the
-machine and tagging it with the single canonical tag.
+Note that images should also be made architecture-independent if possible. The
+build scripts will handle loading the appropriate architecture onto the machine
+and tagging it with the single canonical tag.
Add a `load-<image>` dependency in the Makefile if the image is required for a
particular set of tests. This target will pull the tag from the image repository
diff --git a/images/defs.bzl b/images/defs.bzl
index 61d7bbf73..c1f96e312 100644
--- a/images/defs.bzl
+++ b/images/defs.bzl
@@ -2,30 +2,33 @@
def _docker_image_impl(ctx):
importer = ctx.actions.declare_file(ctx.label.name)
+
importer_content = [
"#!/bin/bash",
"set -euo pipefail",
+ "source_file='%s'" % ctx.file.data.path,
+ "if [[ ! -f \"$source_file\" ]]; then",
+ " source_file='%s'" % ctx.file.data.short_path,
+ "fi",
"exec docker import " + " ".join([
"-c '%s'" % attr
for attr in ctx.attr.statements
- ]) + " " + " ".join([
- "'%s'" % f.path
- for f in ctx.files.data
- ]) + " $1",
+ ]) + " \"$source_file\" $1",
"",
]
+
ctx.actions.write(importer, "\n".join(importer_content), is_executable = True)
return [DefaultInfo(
- runfiles = ctx.runfiles(ctx.files.data),
+ runfiles = ctx.runfiles([ctx.file.data]),
executable = importer,
)]
docker_image = rule(
implementation = _docker_image_impl,
- doc = "Tool to load a Docker image; takes a single parameter (image name).",
+ doc = "Tool to import a Docker image; takes a single parameter (image name).",
attrs = {
"statements": attr.string_list(doc = "Extra Dockerfile directives."),
- "data": attr.label_list(doc = "All image data."),
+ "data": attr.label(doc = "Image filesystem tarball", allow_single_file = [".tgz", ".tar.gz"]),
},
executable = True,
)
diff --git a/nogo.yaml b/nogo.yaml
new file mode 100644
index 000000000..5c1737f59
--- /dev/null
+++ b/nogo.yaml
@@ -0,0 +1,253 @@
+groups:
+ # We define three basic groups: generated (all generated files),
+ # external (all files outside the repository), and internal (all
+ # files within the local repository). We can't enforce many style
+ # checks on generated and external code, so enable those cases
+ # selectively for analyzers below.
+ - name: generated
+ regex: "^(bazel-genfiles|bazel-out|bazel-bin)/"
+ default: true
+ - name: external
+ regex: "^external/"
+ default: false
+ - name: internal
+ regex: ".*"
+ default: true
+global:
+ generated:
+ suppress:
+ # Suppress the basic style checks for
+ # generated code, but keep the analysis
+ # that are required for quality & security.
+ - "should not use ALL_CAPS in Go names"
+ - "should not use underscores"
+ - "comment on exported"
+ - "methods on the same type should have the same receiver name"
+ - "at least one file in a package"
+ - "package comment should be of the form"
+ # Generated code may have dead code paths.
+ - "identical build constraints"
+ - "no value of type"
+ - "is never used"
+ # go_embed_data rules generate unicode literals.
+ - "string literal contains the Unicode format character"
+ - "string literal contains the Unicode control character"
+ - "string literal contains Unicode control characters"
+ - "string literal contains Unicode format and control characters"
+ # Some external code will generate protov1
+ # implementations. These should be ignored.
+ - "proto.* is deprecated"
+ - "xxx_messageInfo_.*"
+ - "receiver name should be a reflection of its identity"
+ # Generated gRPC code is not compliant either.
+ - "error strings should not be capitalized"
+ - "grpc.Errorf is deprecated"
+ # Generated proto code does not always follow capitalization conventions.
+ - "(field|method|struct|type) .* should be .*"
+ # Generated proto code sometimes duplicates imports with aliases.
+ - "duplicate import"
+ internal:
+ suppress:
+ # We use ALL_CAPS for system definitions,
+ # which are common enough in the code base
+ # that we shouldn't annotate exceptions.
+ #
+ # Same story for underscores.
+ - "should not use ALL_CAPS in Go names"
+ - "should not use underscores in Go names"
+ exclude:
+ # A variety of staticcheck and stylecheck
+ # rules apply here. These should be fixed
+ # and removed from here, and the global
+ # rules should be used sparingly.
+ - pkg/abi/linux/fuse.go:22
+ - pkg/abi/linux/fuse.go:25
+ - pkg/abi/linux/socket.go:113
+ - pkg/abi/linux/tty.go:73
+ - pkg/cpuid/cpuid_x86.go:675
+ - pkg/gohacks/gohacks_unsafe.go:33
+ - pkg/log/json.go:30
+ - pkg/log/log.go:359
+ - pkg/metric/metric_test.go:20
+ - pkg/p9/p9test/client_test.go:687
+ - pkg/p9/transport_test.go:196
+ - pkg/pool/pool.go:15
+ - pkg/refs/refcounter.go:510
+ - pkg/refs/refcounter_test.go:169
+ - pkg/refs_vfs2/refs.go:16
+ - pkg/safemem/block_unsafe.go:89
+ - pkg/seccomp/seccomp.go:82
+ - pkg/segment/test/set_functions.go:15
+ - pkg/sentry/arch/signal.go:166
+ - pkg/sentry/arch/signal.go:171
+ - pkg/sentry/control/pprof.go:196
+ - pkg/sentry/devices/memdev/full.go:58
+ - pkg/sentry/devices/memdev/null.go:59
+ - pkg/sentry/devices/memdev/random.go:68
+ - pkg/sentry/devices/memdev/zero.go:86
+ - pkg/sentry/fdimport/fdimport.go:15
+ - pkg/sentry/fs/attr.go:257
+ - pkg/sentry/fsbridge/fs.go:116
+ - pkg/sentry/fsbridge/vfs.go:124
+ - pkg/sentry/fsbridge/vfs.go:70
+ - pkg/sentry/fs/copy_up.go:365
+ - pkg/sentry/fs/copy_up_test.go:65
+ - pkg/sentry/fs/dev/net_tun.go:161
+ - pkg/sentry/fs/dev/net_tun.go:63
+ - pkg/sentry/fs/dev/null.go:97
+ - pkg/sentry/fs/dirent_cache.go:64
+ - pkg/sentry/fs/fdpipe/pipe_opener_test.go:366
+ - pkg/sentry/fs/file_overlay.go:327
+ - pkg/sentry/fs/file_overlay.go:524
+ - pkg/sentry/fs/filetest/filetest.go:55
+ - pkg/sentry/fs/filetest/filetest.go:60
+ - pkg/sentry/fs/fs.go:77
+ - pkg/sentry/fs/fsutil/file.go:290
+ - pkg/sentry/fs/fsutil/file.go:346
+ - pkg/sentry/fs/fsutil/host_file_mapper.go:105
+ - pkg/sentry/fs/fsutil/inode_cached.go:676
+ - pkg/sentry/fs/fsutil/inode_cached.go:772
+ - pkg/sentry/fs/gofer/attr.go:120
+ - pkg/sentry/fs/gofer/fifo.go:33
+ - pkg/sentry/fs/gofer/inode.go:410
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92
+ - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44
+ - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91
+ - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66
+ - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53
+ - pkg/sentry/fsimpl/fuse/request_response.go:71
+ - pkg/sentry/fsimpl/signalfd/signalfd.go:15
+ - pkg/sentry/memmap/memmap.go:103
+ - pkg/sentry/memmap/memmap.go:163
+ - pkg/sentry/mm/aio_context.go:208
+ - pkg/sentry/mm/pma.go:683
+ - pkg/sentry/usage/cpu.go:42
+ - pkg/shim/runsc/runsc.go:16
+ - pkg/shim/runsc/utils.go:16
+ - pkg/shim/v1/proc/deleted_state.go:16
+ - pkg/shim/v1/proc/exec.go:16
+ - pkg/shim/v1/proc/exec_state.go:16
+ - pkg/shim/v1/proc/init.go:16
+ - pkg/shim/v1/proc/init_state.go:16
+ - pkg/shim/v1/proc/io.go:16
+ - pkg/shim/v1/proc/process.go:16
+ - pkg/shim/v1/proc/types.go:16
+ - pkg/shim/v1/proc/utils.go:16
+ - pkg/shim/v1/shim/api.go:16
+ - pkg/shim/v1/shim/platform.go:16
+ - pkg/shim/v1/shim/service.go:16
+ - pkg/shim/v1/utils/annotations.go:15
+ - pkg/shim/v1/utils/utils.go:15
+ - pkg/shim/v1/utils/volumes.go:15
+ - pkg/shim/v2/api.go:16
+ - pkg/shim/v2/epoll.go:18
+ - pkg/shim/v2/options/options.go:15
+ - pkg/shim/v2/options/options.go:24
+ - pkg/shim/v2/options/options.go:26
+ - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all.
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22
+ - pkg/shim/v2/service.go:15
+ - pkg/shim/v2/service_linux.go:18
+ - pkg/state/tests/integer_test.go:23
+ - pkg/state/tests/integer_test.go:28
+ - pkg/sync/rwmutex_test.go:105
+ - pkg/syserr/host_linux.go:35
+ - pkg/usermem/addr.go:34
+ - pkg/usermem/usermem.go:171
+ - pkg/usermem/usermem.go:170
+ - runsc/boot/compat.go:56
+ - test/cmd/test_app/fds.go:171
+ - test/iptables/filter_output.go:251
+ - test/packetimpact/testbench/connections.go:77
+ - tools/bigquery/bigquery.go:106
+ - tools/checkescape/test1/test1.go:108
+ - tools/checkescape/test1/test1.go:122
+ - tools/checkescape/test1/test1.go:137
+ - tools/checkescape/test1/test1.go:151
+ - tools/checkescape/test1/test1.go:170
+ - tools/checkescape/test1/test1.go:39
+ - tools/checkescape/test1/test1.go:45
+ - tools/checkescape/test1/test1.go:50
+ - tools/checkescape/test1/test1.go:64
+ - tools/checkescape/test1/test1.go:80
+ - tools/checkescape/test1/test1.go:94
+analyzers:
+ asmdecl:
+ external: # Enabled.
+ assign:
+ external:
+ exclude:
+ - gazelle/walk/walk.go
+ atomic:
+ external: # Enabled.
+ bools:
+ external: # Enabled.
+ buildtag:
+ external: # Enabled.
+ cgocall:
+ external: # Enabled.
+ shadow: # Disable for now.
+ generated:
+ exclude: [".*"]
+ internal:
+ exclude: [".*"]
+ composites: # Disable for now.
+ generated:
+ exclude: [".*"]
+ internal:
+ exclude: [".*"]
+ errorsas:
+ external: # Enabled.
+ httpresponse:
+ external: # Enabled.
+ loopclosure:
+ external: # Enabled.
+ nilfunc:
+ external: # Enabled.
+ nilness:
+ internal:
+ exclude:
+ - pkg/sentry/platform/kvm/kvm_test.go # Intentional.
+ - tools/bigquery/bigquery.go # False positive.
+ printf:
+ external: # Enabled.
+ shift:
+ external: # Enabled.
+ stringintconv:
+ external:
+ exclude:
+ - ".*protobuf/.*.go" # Bad conversions.
+ - ".*flate/huffman_bit_writer.go" # Bad conversion.
+ # Runtime internal violations.
+ - ".*reflect/value.go"
+ - ".*encoding/xml/xml.go"
+ - ".*runtime/pprof/internal/profile/proto.go"
+ - ".*fmt/scan.go"
+ - ".*go/types/conversions.go"
+ - ".*golang.org/x/net/dns/dnsmessage/message.go"
+ tests:
+ external: # Enabled.
+ unmarshal:
+ external: # Enabled.
+ unreachable:
+ external: # Enabled.
+ unsafeptr:
+ internal:
+ exclude:
+ - ".*_test.go" # Exclude tests.
+ - "pkg/flipcall/.*_unsafe.go" # Special case.
+ - pkg/gohacks/gohacks_unsafe.go # Special case.
+ - pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go # Special case.
+ - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case.
+ - pkg/sentry/platform/kvm/machine_unsafe.go # Special case.
+ - pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go # Special case.
+ - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case.
+ - pkg/sentry/vfs/mount_unsafe.go # Special case.
+ - pkg/state/decode_unsafe.go # Special case.
+ unusedresult:
+ external: # Enabled.
+ checkescape:
+ external: # Enabled.
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 4a26e28de..a0654df2f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -55,6 +55,8 @@ go_library(
"sched.go",
"seccomp.go",
"sem.go",
+ "sem_amd64.go",
+ "sem_arm64.go",
"shm.go",
"signal.go",
"signalfd.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 7df02dd6d..006b5a525 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,6 +121,9 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
+ FS_VERITY_HASH_ALG_SHA256 = 1
+ FS_VERITY_HASH_ALG_SHA512 = 2
+
FS_IOC_ENABLE_VERITY = 1082156677
FS_IOC_MEASURE_VERITY = 3221513862
)
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 487a626cc..1b2f76c0b 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -34,18 +34,6 @@ const (
const SEM_UNDO = 0x1000
-// SemidDS is equivalent to struct semid64_ds.
-//
-// +marshal
-type SemidDS struct {
- SemPerm IPCPerm
- SemOTime TimeT
- SemCTime TimeT
- SemNSems uint64
- unused3 uint64
- unused4 uint64
-}
-
// Sembuf is equivalent to struct sembuf.
//
// +marshal slice:SembufSlice
diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go
new file mode 100644
index 000000000..ab980cb4f
--- /dev/null
+++ b/pkg/abi/linux/sem_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: arch/x86/include/uapi/asm/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ unused1 uint64
+ SemCTime TimeT
+ unused2 uint64
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go
new file mode 100644
index 000000000..521468fb1
--- /dev/null
+++ b/pkg/abi/linux/sem_arm64.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: include/uapi/asm-generic/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ SemCTime TimeT
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go
index 069d0395d..6d1e65cb1 100644
--- a/pkg/bpf/decoder.go
+++ b/pkg/bpf/decoder.go
@@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error {
case B:
w.WriteString("1")
default:
- return fmt.Errorf("Invalid BPF LD size: %v", inst)
+ return fmt.Errorf("invalid BPF LD size: %v", inst)
}
return nil
}
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 2613bc752..f3031fc60 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()}
func Background() Context {
return bgContext
}
+
+// WithValue returns a copy of parent in which the value associated with key is
+// val.
+func WithValue(parent Context, key, val interface{}) Context {
+ return &withValue{
+ Context: parent,
+ key: key,
+ val: val,
+ }
+}
+
+type withValue struct {
+ Context
+ key interface{}
+ val interface{}
+}
+
+// Value implements Context.Value.
+func (ctx *withValue) Value(key interface{}) interface{} {
+ if key == ctx.key {
+ return ctx.val
+ }
+ return ctx.Context.Value(key)
+}
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
index a8fcb2e19..501a9ef21 100644
--- a/pkg/merkletree/BUILD
+++ b/pkg/merkletree/BUILD
@@ -6,12 +6,18 @@ go_library(
name = "merkletree",
srcs = ["merkletree.go"],
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
go_test(
name = "merkletree_test",
srcs = ["merkletree_test.go"],
library = ":merkletree",
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index d8227b8bd..e0a9e56c5 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -18,21 +18,32 @@ package merkletree
import (
"bytes"
"crypto/sha256"
+ "crypto/sha512"
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
const (
// sha256DigestSize specifies the digest size of a SHA256 hash.
sha256DigestSize = 32
+ // sha512DigestSize specifies the digest size of a SHA512 hash.
+ sha512DigestSize = 64
)
// DigestSize returns the size (in bytes) of a digest.
-// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
-func DigestSize() int {
- return sha256DigestSize
+// TODO(b/156980949): Allow config SHA384.
+func DigestSize(hashAlgorithm int) int {
+ switch hashAlgorithm {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ return sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ return sha512DigestSize
+ default:
+ return -1
+ }
}
// Layout defines the scale of a Merkle tree.
@@ -51,11 +62,19 @@ type Layout struct {
// InitLayout initializes and returns a new Layout object describing the structure
// of a tree. dataSize specifies the size of input data in bytes.
-func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
+func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) {
layout := Layout{
blockSize: usermem.PageSize,
- // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
- digestSize: sha256DigestSize,
+ }
+
+ // TODO(b/156980949): Allow config SHA384.
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ layout.digestSize = sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ layout.digestSize = sha512DigestSize
+ default:
+ return Layout{}, fmt.Errorf("unexpected hash algorithms")
}
// treeStart is the offset (in bytes) of the first level of the tree in
@@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
}
layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize)
- return layout
+ return layout, nil
}
// hashesPerBlock() returns the number of digests in each block. For example,
@@ -128,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 {
// meatadata.
type VerityDescriptor struct {
Name string
+ FileSize int64
Mode uint32
UID uint32
GID uint32
@@ -135,16 +155,37 @@ type VerityDescriptor struct {
}
func (d *VerityDescriptor) String() string {
- return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash)
+ return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash)
}
// verify generates a hash from d, and compares it with expected.
-func (d *VerityDescriptor) verify(expected []byte) error {
- h := sha256.Sum256([]byte(d.String()))
+func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error {
+ h, err := hashData([]byte(d.String()), hashAlgorithms)
+ if err != nil {
+ return err
+ }
if !bytes.Equal(h[:], expected) {
return fmt.Errorf("unexpected root hash")
}
return nil
+
+}
+
+// hashData hashes data and returns the result hash based on the hash
+// algorithms.
+func hashData(data []byte, hashAlgorithms int) ([]byte, error) {
+ var digest []byte
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ digestArray := sha256.Sum256(data)
+ digest = digestArray[:]
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ digestArray := sha512.Sum512(data)
+ digest = digestArray[:]
+ default:
+ return nil, fmt.Errorf("unexpected hash algorithms")
+ }
+ return digest, nil
}
// GenerateParams contains the parameters used to generate a Merkle tree.
@@ -161,6 +202,8 @@ type GenerateParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// TreeReader is a reader for the Merkle tree.
TreeReader io.ReaderAt
// TreeWriter is a writer for the Merkle tree.
@@ -176,7 +219,10 @@ type GenerateParams struct {
// Generate returns a hash of a VerityDescriptor, which contains the file
// metadata and the hash from file content.
func Generate(params *GenerateParams) ([]byte, error) {
- layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
+ layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return nil, err
+ }
numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
@@ -218,10 +264,13 @@ func Generate(params *GenerateParams) ([]byte, error) {
return nil, err
}
// Hash the bytes in buf.
- digest := sha256.Sum256(buf)
+ digest, err := hashData(buf, params.HashAlgorithms)
+ if err != nil {
+ return nil, err
+ }
if level == layout.rootLevel() {
- root = digest[:]
+ root = digest
}
// Write the generated hash to the end of the tree file.
@@ -241,13 +290,13 @@ func Generate(params *GenerateParams) ([]byte, error) {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
RootHash: root,
}
- ret := sha256.Sum256([]byte(descriptor.String()))
- return ret[:], nil
+ return hashData([]byte(descriptor.String()), params.HashAlgorithms)
}
// VerifyParams contains the params used to verify a portion of a file against
@@ -269,6 +318,8 @@ type VerifyParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// ReadOffset is the offset of the data range to be verified.
ReadOffset int64
// ReadSize is the size of the data range to be verified.
@@ -293,12 +344,13 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
}
descriptor := VerityDescriptor{
Name: params.Name,
+ FileSize: params.Size,
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
RootHash: root,
}
- return descriptor.verify(params.Expected)
+ return descriptor.verify(params.Expected, params.HashAlgorithms)
}
// Verify verifies the content read from data with offset. The content is
@@ -313,7 +365,10 @@ func Verify(params *VerifyParams) (int64, error) {
if params.ReadSize < 0 {
return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
}
- layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return 0, err
+ }
if params.ReadSize == 0 {
return 0, verifyMetadata(params, &layout)
}
@@ -349,12 +404,13 @@ func Verify(params *VerifyParams) (int64, error) {
}
}
descriptor := VerityDescriptor{
- Name: params.Name,
- Mode: params.Mode,
- UID: params.UID,
- GID: params.GID,
+ Name: params.Name,
+ FileSize: params.Size,
+ Mode: params.Mode,
+ UID: params.UID,
+ GID: params.GID,
}
- if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil {
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
return 0, err
}
@@ -395,7 +451,7 @@ func Verify(params *VerifyParams) (int64, error) {
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
// expected.
-func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error {
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -406,8 +462,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
for level := 0; level < layout.numLevels(); level++ {
// Calculate hash.
if level == 0 {
- digestArray := sha256.Sum256(dataBlock)
- digest = digestArray[:]
+ h, err := hashData(dataBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
} else {
// Read a block in previous level that contains the
// hash we just generated, and generate a next level
@@ -415,8 +474,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil {
return err
}
- digestArray := sha256.Sum256(treeBlock)
- digest = digestArray[:]
+ h, err := hashData(treeBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
}
// Read the digest for the current block and store in
@@ -434,5 +496,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
// Verification for the tree succeeded. Now hash the descriptor with
// the root hash and compare it with expected.
descriptor.RootHash = digest
- return descriptor.verify(expected)
+ return descriptor.verify(expected, hashAlgorithms)
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index e1350ebda..405204d94 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -22,54 +22,114 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
func TestLayout(t *testing.T) {
testCases := []struct {
dataSize int64
+ hashAlgorithms int
dataAndTreeInSameFile bool
+ expectedDigestSize int64
expectedLevelOffset []int64
}{
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0},
},
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
+ expectedLevelOffset: []int64{usermem.PageSize},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
expectedLevelOffset: []int64{usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize},
+ },
+ {
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize},
},
{
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize},
+ },
+ {
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize},
},
{
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize},
},
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize},
+ },
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
- l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile)
+ l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile)
+ if err != nil {
+ t.Fatalf("Failed to InitLayout: %v", err)
+ }
if l.blockSize != int64(usermem.PageSize) {
t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
}
- if l.digestSize != sha256DigestSize {
+ if l.digestSize != tc.expectedDigestSize {
t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
}
if l.numLevels() != len(tc.expectedLevelOffset) {
@@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
- data []byte
- expectedHash []byte
+ data []byte
+ hashAlgorithms int
+ expectedHash []byte
}{
{
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178},
},
{
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171},
},
{
- data: []byte{'a'},
- expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120},
},
{
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246},
+ },
+ {
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84},
},
}
@@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) {
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
+ HashAlgorithms: tc.hashAlgorithms,
TreeReader: &tree,
TreeWriter: &tree,
DataAndTreeInSameFile: dataAndTreeInSameFile,
@@ -189,6 +275,7 @@ func TestVerify(t *testing.T) {
// fail, otherwise Verify should still succeed.
modifyByte int64
modifyName bool
+ modifySize bool
modifyMode bool
modifyUID bool
modifyGID bool
@@ -237,6 +324,15 @@ func TestVerify(t *testing.T) {
modifyName: true,
shouldSucceed: false,
},
+ // Modified size should fail verification.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ modifySize: true,
+ shouldSucceed: false,
+ },
// Modified mode should fail verification.
{
dataSize: usermem.PageSize,
@@ -348,77 +444,84 @@ func TestVerify(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Flip a bit in data and checks Verify results.
- var buf bytes.Buffer
- data[tc.modifyByte] ^= 1
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: tc.dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: tc.verifyStart,
- ReadSize: tc.verifySize,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if tc.modifyName {
- verifyParams.Name = defaultName + "abc"
- }
- if tc.modifyMode {
- verifyParams.Mode = defaultMode + 1
- }
- if tc.modifyUID {
- verifyParams.UID = defaultUID + 1
- }
- if tc.modifyGID {
- verifyParams.GID = defaultGID + 1
- }
- if tc.shouldSucceed {
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed when expected to succeed: %v", err)
+ // Flip a bit in data and checks Verify results.
+ var buf bytes.Buffer
+ data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
}
- if n != tc.verifySize {
- t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
}
- if int64(buf.Len()) != tc.verifySize {
- t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ if tc.modifySize {
+ verifyParams.Size--
}
- if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
}
- } else {
- if _, err := Verify(&verifyParams); err == nil {
- t.Errorf("Verification succeeded when expected to fail")
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
+ }
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
+ if tc.shouldSucceed {
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed when expected to succeed: %v", err)
+ }
+ if n != tc.verifySize {
+ t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ }
+ if int64(buf.Len()) != tc.verifySize {
+ t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ }
+ if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
+ } else {
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Errorf("Verification succeeded when expected to fail")
+ }
}
}
}
@@ -435,87 +538,91 @@ func TestVerifyRandom(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Pick a random portion of data.
- start := rand.Int63n(dataSize - 1)
- size := rand.Int63n(dataSize) + 1
+ // Pick a random portion of data.
+ start := rand.Int63n(dataSize - 1)
+ size := rand.Int63n(dataSize) + 1
- var buf bytes.Buffer
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: start,
- ReadSize: size,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: start,
+ ReadSize: size,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- // Checks that the random portion of data from the original data is
- // verified successfully.
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed for correct data: %v", err)
- }
- if size > dataSize-start {
- size = dataSize - start
- }
- if n != size {
- t.Errorf("Got Verify output size %d, want %d", n, size)
- }
- if int64(buf.Len()) != size {
- t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
- }
- if !bytes.Equal(data[start:start+size], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
- }
+ // Checks that the random portion of data from the original data is
+ // verified successfully.
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed for correct data: %v", err)
+ }
+ if size > dataSize-start {
+ size = dataSize - start
+ }
+ if n != size {
+ t.Errorf("Got Verify output size %d, want %d", n, size)
+ }
+ if int64(buf.Len()) != size {
+ t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
+ }
+ if !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
- // Verify that modified metadata should fail verification.
- buf.Reset()
- verifyParams.Name = defaultName + "abc"
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verify succeeded for modified metadata, expect failure")
- }
+ // Verify that modified metadata should fail verification.
+ buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
- // Flip a random bit in randPortion, and check that verification fails.
- buf.Reset()
- randBytePos := rand.Int63n(size)
- data[start+randBytePos] ^= 1
- verifyParams.File = bytes.NewReader(data)
- verifyParams.Name = defaultName
+ // Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
+ randBytePos := rand.Int63n(size)
+ data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verification succeeded for modified data, expect failure")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
+ }
}
}
}
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 699ea8ac3..6992e1de8 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -319,7 +319,8 @@ func makeStackKey(pcs []uintptr) stackKey {
return key
}
-func recordStack() []uintptr {
+// RecordStack constructs and returns the PCs on the current stack.
+func RecordStack() []uintptr {
pcs := make([]uintptr, maxStackFrames)
n := runtime.Callers(1, pcs)
if n == 0 {
@@ -342,7 +343,8 @@ func recordStack() []uintptr {
return v
}
-func formatStack(pcs []uintptr) string {
+// FormatStack converts the given stack into a readable format.
+func FormatStack(pcs []uintptr) string {
frames := runtime.CallersFrames(pcs)
var trace bytes.Buffer
for {
@@ -367,7 +369,7 @@ func (r *AtomicRefCount) finalize() {
if n := r.ReadRefs(); n != 0 {
msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n)
if len(r.stack) != 0 {
- msg += ":\nCaller:\n" + formatStack(r.stack)
+ msg += ":\nCaller:\n" + FormatStack(r.stack)
} else {
msg += " (enable trace logging to debug)"
}
@@ -392,7 +394,7 @@ func (r *AtomicRefCount) EnableLeakCheck(name string) {
case NoLeakChecking:
return
case LeaksLogTraces:
- r.stack = recordStack()
+ r.stack = RecordStack()
}
r.name = name
runtime.SetFinalizer(r, (*AtomicRefCount).finalize)
diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD
index 577b827a5..bfa1daa10 100644
--- a/pkg/refs_vfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -8,6 +8,9 @@ go_template(
srcs = [
"refs_template.go",
],
+ opt_consts = [
+ "logTrace",
+ ],
types = [
"T",
],
@@ -19,8 +22,16 @@ go_template(
)
go_library(
- name = "refs_vfs2",
- srcs = ["refs.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/context"],
+ name = "refsvfs2",
+ srcs = [
+ "refs.go",
+ "refs_map.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go
index 99a074e96..ef8beb659 100644
--- a/pkg/refs_vfs2/refs.go
+++ b/pkg/refsvfs2/refs.go
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package refs_vfs2 defines an interface for a reference-counted object.
-package refs_vfs2
+// Package refsvfs2 defines an interface for a reference-counted object.
+package refsvfs2
import (
"gvisor.dev/gvisor/pkg/context"
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
new file mode 100644
index 000000000..9fbc5466f
--- /dev/null
+++ b/pkg/refsvfs2/refs_map.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refsvfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+var (
+ // liveObjects is a global map of reference-counted objects. Objects are
+ // inserted when leak check is enabled, and they are removed when they are
+ // destroyed. It is protected by liveObjectsMu.
+ liveObjects map[CheckedObject]struct{}
+ liveObjectsMu sync.Mutex
+)
+
+// CheckedObject represents a reference-counted object with an informative
+// leak detection message.
+type CheckedObject interface {
+ // RefType is the type of the reference-counted object.
+ RefType() string
+
+ // LeakMessage supplies a warning to be printed upon leak detection.
+ LeakMessage() string
+
+ // LogRefs indicates whether reference-related events should be logged.
+ LogRefs() bool
+}
+
+func init() {
+ liveObjects = make(map[CheckedObject]struct{})
+}
+
+// leakCheckEnabled returns whether leak checking is enabled. The following
+// functions should only be called if it returns true.
+func leakCheckEnabled() bool {
+ return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking
+}
+
+// Register adds obj to the live object map.
+func Register(obj CheckedObject) {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ if _, ok := liveObjects[obj]; ok {
+ panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj))
+ }
+ liveObjects[obj] = struct{}{}
+ liveObjectsMu.Unlock()
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "registered")
+ }
+ }
+}
+
+// Unregister removes obj from the live object map.
+func Unregister(obj CheckedObject) {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ if _, ok := liveObjects[obj]; !ok {
+ panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj))
+ }
+ delete(liveObjects, obj)
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, "unregistered")
+ }
+ }
+}
+
+// LogIncRef logs a reference increment.
+func LogIncRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("IncRef to %d", refs))
+ }
+}
+
+// LogTryIncRef logs a successful TryIncRef call.
+func LogTryIncRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs))
+ }
+}
+
+// LogDecRef logs a reference decrement.
+func LogDecRef(obj CheckedObject, refs int64) {
+ if leakCheckEnabled() && obj.LogRefs() {
+ logEvent(obj, fmt.Sprintf("DecRef to %d", refs))
+ }
+}
+
+// logEvent logs a message for the given reference-counted object.
+//
+// obj.LogRefs() should be checked before calling logEvent, in order to avoid
+// calling any text processing needed to evaluate msg.
+func logEvent(obj CheckedObject, msg string) {
+ log.Infof("[%s %p] %s:", obj.RefType(), obj, msg)
+ log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack()))
+}
+
+// DoLeakCheck iterates through the live object map and logs a message for each
+// object. It is called once no reference-counted objects should be reachable
+// anymore, at which point anything left in the map is considered a leak.
+func DoLeakCheck() {
+ if leakCheckEnabled() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ leaked := len(liveObjects)
+ if leaked > 0 {
+ log.Warningf("Leak checking detected %d leaked objects:", leaked)
+ for obj := range liveObjects {
+ log.Warningf(obj.LeakMessage())
+ }
+ }
+ }
+}
diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index d9b552896..8f50b4ee6 100644
--- a/pkg/refs_vfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -21,20 +21,24 @@ package refs_template
import (
"fmt"
- "runtime"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/log"
- refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
)
+// enableLogging indicates whether reference-related events should be logged (with
+// stack traces). This is false by default and should only be set to true for
+// debugging purposes, as it can generate an extremely large amount of output
+// and drastically degrade performance.
+const enableLogging = false
+
// T is the type of the reference counted object. It is only used to customize
// debug output when leak checking.
type T interface{}
-// ownerType is used to customize logging. Note that we use a pointer to T so
-// that we do not copy the entire object when passed as a format parameter.
-var ownerType *T
+// obj is used to customize logging. Note that we use a pointer to T so that
+// we do not copy the entire object when passed as a format parameter.
+var obj *T
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
@@ -42,11 +46,6 @@ var ownerType *T
// Note that the number of references is actually refCount + 1 so that a default
// zero-value Refs object contains one reference.
//
-// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in
-// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount.
-// This will allow us to add stack trace information to the leak messages
-// without growing the size of Refs.
-//
// +stateify savable
type Refs struct {
// refCount is composed of two fields:
@@ -59,24 +58,24 @@ type Refs struct {
refCount int64
}
-func (r *Refs) finalize() {
- var note string
- switch refs_vfs1.GetLeakMode() {
- case refs_vfs1.NoLeakChecking:
- return
- case refs_vfs1.UninitializedLeakChecking:
- note = "(Leak checker uninitialized): "
- }
- if n := r.ReadRefs(); n != 0 {
- log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n)
- }
+// RefType implements refsvfs2.CheckedObject.RefType.
+func (r *Refs) RefType() string {
+ return fmt.Sprintf("%T", obj)[1:]
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (r *Refs) LeakMessage() string {
+ return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs())
}
-// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+func (r *Refs) LogRefs() bool {
+ return enableLogging
+}
+
+// EnableLeakCheck enables reference leak checking on r.
func (r *Refs) EnableLeakCheck() {
- if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
- runtime.SetFinalizer(r, (*Refs).finalize)
- }
+ refsvfs2.Register(r)
}
// ReadRefs returns the current number of references. The returned count is
@@ -90,8 +89,10 @@ func (r *Refs) ReadRefs() int64 {
//
//go:nosplit
func (r *Refs) IncRef() {
- if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
- panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType))
+ v := atomic.AddInt64(&r.refCount, 1)
+ refsvfs2.LogIncRef(r, v+1)
+ if v <= 0 {
+ panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType()))
}
}
@@ -104,15 +105,15 @@ func (r *Refs) IncRef() {
//go:nosplit
func (r *Refs) TryIncRef() bool {
const speculativeRef = 1 << 32
- v := atomic.AddInt64(&r.refCount, speculativeRef)
- if int32(v) < 0 {
+ if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 {
// This object has already been freed.
atomic.AddInt64(&r.refCount, -speculativeRef)
return false
}
// Turn into a real reference.
- atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ v := atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ refsvfs2.LogTryIncRef(r, v+1)
return true
}
@@ -129,14 +130,23 @@ func (r *Refs) TryIncRef() bool {
//
//go:nosplit
func (r *Refs) DecRef(destroy func()) {
- switch v := atomic.AddInt64(&r.refCount, -1); {
+ v := atomic.AddInt64(&r.refCount, -1)
+ refsvfs2.LogDecRef(r, v+1)
+ switch {
case v < -1:
- panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType))
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType()))
case v == -1:
+ refsvfs2.Unregister(r)
// Call the destructor.
if destroy != nil {
destroy()
}
}
}
+
+func (r *Refs) afterLoad() {
+ if r.ReadRefs() > 0 {
+ r.EnableLeakCheck()
+ }
+}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index 41feeffe3..d800f2c85 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
s.Kernel.Kill(kernel.ExitStatus{})
},
}
- return saveOpts.Save(s.Kernel, s.Watchdog)
+ return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog)
}
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 655ea549b..ff5d49fbd 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -39,6 +39,8 @@ const (
)
// tunDevice implements vfs.Device for /dev/net/tun.
+//
+// +stateify savable
type tunDevice struct{}
// Open implements vfs.Device.Open.
@@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt
}
// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+//
+// +stateify savable
type tunFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 1390a9a7f..4468f5dd2 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() {
f.mappings = make(map[uint64]mapping)
}
+// IsInited returns true if f.Init() has been called. This is used when
+// restoring a checkpoint that contains a HostFileMapper that may or may not
+// have been initialized.
+func (f *HostFileMapper) IsInited() bool {
+ return f.refs != nil
+}
+
// NewHostFileMapper returns an initialized HostFileMapper allocated on the
// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 3c66dc3c2..6b3627813 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -25,7 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
@@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode,
}
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Wrap the fileOps with our Fifo.
iops := &fifo{
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index e555672ad..52061175f 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
}
// GetFile implements fs.InodeOperations.GetFile.
-func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
- return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil
+ return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil
}
// +stateify savable
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 22d658acf..450044c9c 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo
"gid_map": newGIDMap(t, msrc),
"io": newIO(t, msrc, isThreadGroup),
"maps": newMaps(t, msrc),
+ "mem": newMem(t, msrc),
"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
"net": newNetDir(t, msrc),
@@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
+// memData implements fs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memData struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// memDataFile implements fs.FileOperations for /proc/[pid]/mem.
+//
+// +stateify savable
+type memDataFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ inode := &memData{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, inode, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (m *memData) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, m.t, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(m.t); err != nil {
+ return nil, err
+ }
+ // Enable random access reads
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ mm, err := getTaskMM(m.t)
+ if err != nil {
+ return 0, nil
+ }
+ defer mm.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
// mapsData implements seqfile.SeqSource for /proc/[pid]/maps.
//
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index fc0498f17..d6c65301c 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
// Continue.
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
@@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
// Write to that memory as usual.
seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
-
- default:
- break
}
}
return done, nil
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 998b697ca..cf4ed5de0 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -336,7 +336,7 @@ type Fifo struct {
// NewFifo creates a new named pipe.
func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode {
// First create a pipe.
- p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize)
// Build pipe InodeOperations.
iops := pipe.NewInodeOperations(ctx, perms, p)
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 84baaac66..6af3c3781 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "root_inode_refs.go",
package = "devpts",
prefix = "rootInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "rootInode",
},
@@ -33,6 +33,7 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index d5c5aaa8c..346cca558 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir
}
fstype.initOnce.Do(func() {
- fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds)
if err != nil {
fstype.initErr = err
return
@@ -93,7 +93,7 @@ type filesystem struct {
// newFilesystem creates a new devpts filesystem with root directory and ptmx
// master inode. It returns the filesystem and root Dentry.
-func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -108,19 +108,19 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root := &rootInode{
replicas: make(map[uint32]*replicaInode),
}
- root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
+ root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
root.EnableLeakCheck()
var rootD kernfs.Dentry
- rootD.Init(&fs.Filesystem, root)
+ rootD.InitRoot(&fs.Filesystem, root)
// Construct the pts master inode and dentry. Linux always uses inode
// id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
master := &masterInode{
root: root,
}
- master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
+ master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
// Add the master as a child of the root.
links := root.OrderedChildren.Populate(map[string]kernfs.Inode{
@@ -170,7 +170,7 @@ type rootInode struct {
var _ kernfs.Inode = (*rootInode)(nil)
// allocateTerminal creates a new Terminal and installs a pts node for it.
-func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) {
i.mu.Lock()
defer i.mu.Unlock()
if i.nextIdx == math.MaxUint32 {
@@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
}
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
- replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
i.replicas[idx] = replica
return t, nil
@@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
i.mu.Lock()
defer i.mu.Unlock()
+ i.InodeAttrs.TouchAtime(ctx, mnt)
ids := make([]int, 0, len(i.replicas))
for id := range i.replicas {
ids = append(ids, int(id))
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index e6b0e81cf..ae95fdd08 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -100,10 +100,10 @@ type lineDiscipline struct {
column int
// masterWaiter is used to wait on the master end of the TTY.
- masterWaiter waiter.Queue `state:"zerovalue"`
+ masterWaiter waiter.Queue
// replicaWaiter is used to wait on the replica end of the TTY.
- replicaWaiter waiter.Queue `state:"zerovalue"`
+ replicaWaiter waiter.Queue
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index fda30fb93..e91fa26a4 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil)
// Open implements kernfs.Inode.Open.
func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- t, err := mi.root.allocateTerminal(rp.Credentials())
+ t, err := mi.root.allocateTerminal(ctx, rp.Credentials())
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD
index 01bbee5ad..e49a04c1b 100644
--- a/pkg/sentry/fsimpl/devtmpfs/BUILD
+++ b/pkg/sentry/fsimpl/devtmpfs/BUILD
@@ -4,7 +4,10 @@ licenses(["notice"])
go_library(
name = "devtmpfs",
- srcs = ["devtmpfs.go"],
+ srcs = [
+ "devtmpfs.go",
+ "save_restore.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
new file mode 100644
index 000000000..28832d850
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devtmpfs
+
+// afterLoad is invoked by stateify.
+func (fst *FilesystemType) afterLoad() {
+ if fst.fs != nil {
+ // Ensure that we don't create another filesystem.
+ fst.initOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 1c27ad700..5b29f2358 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -43,7 +43,7 @@ type EventFileDescription struct {
// queue is used to notify interested parties when the event object
// becomes readable or writable.
- queue waiter.Queue `state:"zerovalue"`
+ queue waiter.Queue
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 045d7ab08..2158b1bbc 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "inode_refs.go",
package = "fuse",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -49,6 +49,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/kernfs",
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 5986133e9..95c475a65 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
readPayload.MarshalUnsafe(outBuf[outHdrLen:])
outIOseq := usermem.BytesIOSequence(outBuf)
- n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
if err != nil {
t.Fatalf("Write failed :%v", err)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index e39df21c6..6de416da0 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// root is the fusefs root directory.
- root := fs.newRootInode(creds, fsopts.rootMode)
+ root := fs.newRoot(ctx, creds, fsopts.rootMode)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
@@ -284,21 +284,21 @@ type inode struct {
link string
}
-func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
i := &inode{fs: fs, nodeID: 1}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
+ i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
var d kernfs.Dentry
- d.Init(&fs.Filesystem, i)
+ d.InitRoot(&fs.Filesystem, i)
return &d
}
-func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
+func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
i := &inode{fs: fs, nodeID: nodeID}
creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
- i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
+ i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
@@ -424,7 +424,7 @@ func (i *inode) Keep() bool {
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
@@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) {
return nil, syserror.EIO
}
- child := i.fs.newInode(out.NodeID, out.Attr)
+ child := i.fs.newInode(ctx, out.NodeID, out.Attr)
return child, nil
}
@@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return linux.FUSEAttr{}, err
@@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 625d1547f..2d396e84c 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u
// May need to update the signature.
i := fd.inode()
- // TODO(gvisor.dev/issue/1193): Invalidate or update atime.
+ i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount())
// Reached EOF.
if sizeRead < size {
@@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
Flags: fd.statusFlags(),
}
+ inode := fd.inode()
var written uint32
// This loop is intended for fragmented write where the bytes to write is
@@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in)
+ req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
if err != nil {
return 0, err
}
@@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
break
}
}
+ inode.InodeAttrs.TouchCMtime(ctx)
return written, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index ad0afc41b..4c3e9acf8 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"host_named_pipe.go",
"p9file.go",
"regular_file.go",
+ "save_restore.go",
"socket.go",
"special_file.go",
"symlink.go",
@@ -53,6 +54,7 @@ go_library(
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
@@ -70,6 +72,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 18c884b59..ce1b2a390 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -16,16 +16,17 @@ package gofer
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
child := &dentry{
refs: 1, // held by d
fs: d.fs,
- ino: d.fs.nextSyntheticIno(),
+ ino: d.fs.nextIno(),
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
@@ -100,6 +101,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
hostFD: -1,
nlink: uint32(2),
}
+ refsvfs2.Register(child)
switch opts.mode.FileType() {
case linux.S_IFDIR:
// Nothing else needs to be done.
@@ -235,7 +237,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
}
dirent := vfs.Dirent{
Name: p9d.Name,
- Ino: uint64(inoFromPath(p9d.QID.Path)),
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
NextOff: int64(len(dirents) + 1),
}
// p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 94d96261b..bbb01148b 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -30,12 +30,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
- // Snapshot current syncable dentries and special files.
+ // Snapshot current syncable dentries and special file FDs.
fs.syncMu.Lock()
ds := make([]*dentry, 0, len(fs.syncableDentries))
for d := range fs.syncableDentries {
@@ -53,22 +52,28 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
- // Sync regular files.
+ // Sync syncable dentries.
for _, d := range ds {
- err := d.syncCachedFile(ctx)
+ err := d.syncCachedFile(ctx, true /* forFilesystemSync */)
d.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- err := sffd.Sync(ctx)
+ err := sffd.sync(ctx, true /* forFilesystemSync */)
sffd.vfsfd.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
@@ -229,7 +234,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
return nil, err
}
if child != nil {
- if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ if !file.isNil() && qid.Path == child.qidPath {
// The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
child.updateFromP9AttrsLocked(attrMask, &attr)
@@ -256,7 +261,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// treat their invalidation as deletion.
child.setDeleted()
parent.syntheticChildren--
- child.decRefLocked()
+ child.decRefNoCaching()
parent.dirents = nil
}
*ds = appendDentry(*ds, child)
@@ -366,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if len(name) > maxFilenameLen {
return syserror.ENAMETOOLONG
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.isDeleted() {
return syserror.ENOENT
}
@@ -383,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if createInSyntheticDir == nil {
return syserror.EPERM
}
@@ -402,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil && child.isSynthetic() {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// The existence of a non-synthetic dentry at name would be inconclusive
// because the file it represents may have been deleted from the remote
// filesystem, so we would need to make an RPC to revalidate the dentry.
@@ -422,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if child := parent.children[name]; child != nil {
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
// No cached dentry exists; however, there might still be an existing file
// at name. As above, we attempt the file creation RPC anyway.
if err := createInRemoteDir(parent, name, &ds); err != nil {
@@ -625,7 +636,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
child.setDeleted()
if child.isSynthetic() {
parent.syntheticChildren--
- child.decRefLocked()
+ child.decRefNoCaching()
}
ds = appendDentry(ds, child)
}
@@ -836,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
mode: opts.Mode,
kuid: creds.EffectiveKUID,
kgid: creds.EffectiveKGID,
- pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize),
})
return nil
}
@@ -1355,7 +1366,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
replaced.setDeleted()
if replaced.isSynthetic() {
newParent.syntheticChildren--
- replaced.decRefLocked()
+ replaced.decRefNoCaching()
}
ds = appendDentry(ds, replaced)
}
@@ -1364,7 +1375,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// with reference counts and queue oldParent for checkCachingLocked if the
// parent isn't actually changing.
if oldParent != newParent {
- oldParent.decRefLocked()
+ oldParent.decRefNoCaching()
ds = appendDentry(ds, oldParent)
newParent.IncRef()
if renamed.isSynthetic() {
@@ -1512,7 +1523,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
d.IncRef()
return &endpoint{
dentry: d,
- file: d.file.file,
path: opts.Addr,
}, nil
}
@@ -1591,7 +1601,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
-
-func (fs *filesystem) nextSyntheticIno() inodeNumber {
- return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
-}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index f1dad1b08..6f82ce61b 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -26,6 +26,9 @@
// *** "memmap.Mappable locks taken by Translate" below this point
// dentry.handleMu
// dentry.dataMu
+// filesystem.inoMu
+// specialFileFD.mu
+// specialFileFD.bufMu
//
// Locking dentry.dirMu in multiple dentries requires that either ancestor
// dentries are locked before descendant dentries, or that filesystem.renameMu
@@ -36,7 +39,6 @@ import (
"fmt"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
@@ -44,6 +46,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -53,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/usermem"
@@ -81,7 +86,7 @@ type filesystem struct {
iopts InternalFilesystemOptions
// client is the client used by this filesystem. client is immutable.
- client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ client *p9.Client `state:"nosave"`
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -89,6 +94,9 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
+ // root is the root dentry. root is immutable.
+ root *dentry
+
// renameMu serves two purposes:
//
// - It synchronizes path resolution with renaming initiated by this
@@ -103,39 +111,35 @@ type filesystem struct {
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
- // cachedDentriesLen is the number of dentries in cachedDentries. These
- // fields are protected by renameMu.
+ // cachedDentriesLen is the number of dentries in cachedDentries. These fields
+ // are protected by renameMu.
cachedDentries dentryList
cachedDentriesLen uint64
- // syncableDentries contains all dentries in this filesystem for which
- // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
- // These fields are protected by syncMu.
+ // syncableDentries contains all non-synthetic dentries. specialFileFDs
+ // contains all open specialFileFDs. These fields are protected by syncMu.
syncMu sync.Mutex `state:"nosave"`
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
- // syntheticSeq stores a counter to used to generate unique inodeNumber for
- // synthetic dentries.
- syntheticSeq uint64
-}
+ // inoByQIDPath maps previously-observed QID.Paths to inode numbers
+ // assigned to those paths. inoByQIDPath is not preserved across
+ // checkpoint/restore because QIDs may be reused between different gofer
+ // processes, so QIDs may be repeated for different files across
+ // checkpoint/restore. inoByQIDPath is protected by inoMu.
+ inoMu sync.Mutex `state:"nosave"`
+ inoByQIDPath map[uint64]uint64 `state:"nosave"`
-// inodeNumber represents inode number reported in Dirent.Ino. For regular
-// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
-// have have their inodeNumber generated sequentially, with the MSB reserved to
-// prevent conflicts with regular dentries.
-//
-// +stateify savable
-type inodeNumber uint64
+ // lastIno is the last inode number assigned to a file. lastIno is accessed
+ // using atomic memory operations.
+ lastIno uint64
-// Reserve MSB for synthetic mounts.
-const syntheticInoMask = uint64(1) << 63
+ // savedDentryRW records open read/write handles during save/restore.
+ savedDentryRW map[*dentry]savedDentryRW
-func inoFromPath(path uint64) inodeNumber {
- if path&syntheticInoMask != 0 {
- log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
- }
- return inodeNumber(path &^ syntheticInoMask)
+ // released is nonzero once filesystem.Release has been called. It is accessed
+ // with atomic memory operations.
+ released int32
}
// +stateify savable
@@ -149,8 +153,7 @@ type filesystemOptions struct {
msize uint32
version string
- // maxCachedDentries is the maximum number of dentries with 0 references
- // retained by the client.
+ // maxCachedDentries is the maximum size of filesystem.cachedDentries.
maxCachedDentries uint64
// If forcePageCache is true, host FDs may not be used for application
@@ -247,6 +250,10 @@ const (
//
// +stateify savable
type InternalFilesystemOptions struct {
+ // If UniqueID is non-empty, it is an opaque string used to reassociate the
+ // filesystem with a new server FD during restoration from checkpoint.
+ UniqueID string
+
// If LeakConnection is true, do not close the connection to the server
// when the Filesystem is released. This is necessary for deployments in
// which servers can handle only a single client and report failure if that
@@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mopts := vfs.GenericParseMountOptions(opts.Data)
var fsopts filesystemOptions
- // Check that the transport is "fd".
- trans, ok := mopts["trans"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "trans")
- if trans != "fd" {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
- return nil, nil, syserror.EINVAL
- }
-
- // Check that read and write FDs are provided and identical.
- rfdstr, ok := mopts["rfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "rfdno")
- rfd, err := strconv.Atoi(rfdstr)
+ fd, err := getFDFromMountOptionsMap(ctx, mopts)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
- return nil, nil, syserror.EINVAL
- }
- wfdstr, ok := mopts["wfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "wfdno")
- wfd, err := strconv.Atoi(wfdstr)
- if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
- return nil, nil, syserror.EINVAL
- }
- if rfd != wfd {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
- return nil, nil, syserror.EINVAL
+ return nil, nil, err
}
- fsopts.fd = rfd
+ fsopts.fd = fd
// Get the attach name.
fsopts.aname = "/"
@@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// If !ok, iopts being the zero value is correct.
- // Establish a connection with the server.
- conn, err := unet.NewSocket(fsopts.fd)
+ // Construct the filesystem object.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ iopts: iopts,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ devMinor: devMinor,
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
- // Perform version negotiation with the server.
- ctx.UninterruptibleSleepStart(false)
- client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- conn.Close()
+ // Connect to the server.
+ if err := fs.dial(ctx); err != nil {
return nil, nil, err
}
- // Ownership of conn has been transferred to client.
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fsopts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
- // Construct the filesystem object.
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- attachFile.close(ctx)
- client.Close()
- return nil, nil, err
- }
- fs := &filesystem{
- mfp: mfp,
- opts: fsopts,
- iopts: iopts,
- client: client,
- clock: ktime.RealtimeClockFromContext(ctx),
- devMinor: devMinor,
- syncableDentries: make(map[*dentry]struct{}),
- specialFileFDs: make(map[*specialFileFD]struct{}),
- }
- fs.vfsfs.Init(vfsObj, &fstype, fs)
-
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
@@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
// Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is deliberately leaked to prevent the root from
- // being "cached" and subsequently evicted. Its resources will still be
- // cleaned up by fs.Release().
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
root.refs = 2
+ fs.root = root
return &fs.vfsfs, &root.vfsd, nil
}
+func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok || trans != "fd" {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr)
+ return -1, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr)
+ return -1, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return -1, syserror.EINVAL
+ }
+ return rfd, nil
+}
+
+// Preconditions: fs.client == nil.
+func (fs *filesystem) dial(ctx context.Context) error {
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return err
+ }
+ // Ownership of conn has been transferred to client.
+
+ fs.client = client
+ return nil
+}
+
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
- mf := fs.mfp.MemoryFile()
+ atomic.StoreInt32(&fs.released, 1)
+ mf := fs.mfp.MemoryFile()
fs.syncMu.Lock()
for d := range fs.syncableDentries {
d.handleMu.Lock()
d.dataMu.Lock()
if h := d.writeHandleLocked(); h.isOpen() {
// Write dirty cached data to the remote file.
- if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil {
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
}
// TODO(jamieliu): Do we need to flushf/fsync d?
@@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) {
// fs.
fs.syncMu.Unlock()
+ // If leak checking is enabled, release all outstanding references in the
+ // filesystem. We deliberately avoid doing this outside of leak checking; we
+ // have released all external resources above rather than relying on dentry
+ // destructors.
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ fs.renameMu.Lock()
+ fs.root.releaseSyntheticRecursiveLocked(ctx)
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // An extra reference was held by the filesystem on the root to prevent it from
+ // being cached/evicted.
+ fs.root.DecRef(ctx)
+ }
+
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
fs.client.Close()
@@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
}
+// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements
+// the reference count on every synthetic dentry. Synthetic dentries have one
+// reference for existence that should be dropped during filesystem.Release.
+//
+// Precondition: d.fs.renameMu is locked.
+func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
+ if d.isSynthetic() {
+ d.decRefNoCaching()
+ d.checkCachingLocked(ctx)
+ }
+ if d.isDir() {
+ var children []*dentry
+ d.dirMu.Lock()
+ for _, child := range d.children {
+ children = append(children, child)
+ }
+ d.dirMu.Unlock()
+ for _, child := range children {
+ if child != nil {
+ child.releaseSyntheticRecursiveLocked(ctx)
+ }
+ }
+ }
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -574,12 +635,15 @@ type dentry struct {
// filesystem.renameMu.
name string
+ // qidPath is the p9.QID.Path for this file. qidPath is immutable.
+ qidPath uint64
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
// that does not exist on the remote filesystem. As of this writing, the
// only files that can be synthetic are sockets, pipes, and directories.
- file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ file p9file `state:"nosave"`
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
@@ -623,12 +687,12 @@ type dentry struct {
// To mutate:
// - Lock metadataMu and use atomic operations to update because we might
// have atomic readers that don't hold the lock.
- metadataMu sync.Mutex `state:"nosave"`
- ino inodeNumber // immutable
- mode uint32 // type is immutable, perms are mutable
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- blockSize uint32 // 0 if unknown
+ metadataMu sync.Mutex `state:"nosave"`
+ ino uint64 // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
// Timestamps, all nsecs from the Unix epoch.
atime int64
mtime int64
@@ -679,9 +743,9 @@ type dentry struct {
// (isNil() == false), it may be mutated with handleMu locked, but cannot
// be closed until the dentry is destroyed.
handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- hostFD int32
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ hostFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d := &dentry{
fs: fs,
+ qidPath: qid.Path,
file: file,
- ino: inoFromPath(qid.Path),
+ ino: fs.inoFromQIDPath(qid.Path),
mode: uint32(attr.Mode),
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
@@ -795,13 +860,28 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d.nlink = uint32(attr.NLink)
}
d.vfsd.Init(d)
-
+ refsvfs2.Register(d)
fs.syncMu.Lock()
fs.syncableDentries[d] = struct{}{}
fs.syncMu.Unlock()
return d, nil
}
+func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+ if ino, ok := fs.inoByQIDPath[qidPath]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByQIDPath[qidPath] = ino
+ return ino
+}
+
+func (fs *filesystem) nextIno() uint64 {
+ return atomic.AddUint64(&fs.lastIno, 1)
+}
+
func (d *dentry) isSynthetic() bool {
return d.file.isNil()
}
@@ -853,7 +933,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
-// Preconditions: !d.isSynthetic()
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
// Use d.readFile or d.writeFile, which represent 9P fids that have been
// opened, in preference to d.file, which represents a 9P fid that has not.
@@ -916,10 +996,10 @@ func (d *dentry) statTo(stat *linux.Statx) {
// This is consistent with regularFileFD.Seek(), which treats regular files
// as having no holes.
stat.Blocks = (stat.Size + 511) / 512
- stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
- stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
- stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
- stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime))
+ stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime))
stat.DevMajor = linux.UNNAMED_MAJOR
stat.DevMinor = d.fs.devMinor
}
@@ -967,10 +1047,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// Use client clocks for timestamps.
now = d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW {
- stat.Atime = statxTimestampFromDentry(now)
+ stat.Atime = linux.NsecToStatxTimestamp(now)
}
if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW {
- stat.Mtime = statxTimestampFromDentry(now)
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
}
}
@@ -1029,11 +1109,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
if stat.Mask&linux.STATX_ATIME != 0 {
- atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
if stat.Mask&linux.STATX_MTIME != 0 {
- atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
@@ -1139,17 +1219,19 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 {
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkCachingLocked().
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -1157,22 +1239,41 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ if d.decRefNoCaching() == 0 {
d.fs.renameMu.Lock()
d.checkCachingLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
- panic("gofer.dentry.DecRef() called without holding a reference")
}
}
-// decRefLocked decrements d's reference count without calling
+// decRefNoCaching decrements d's reference count without calling
// d.checkCachingLocked, even if d's reference count reaches 0; callers are
// responsible for ensuring that d.checkCachingLocked will be called later.
-func (d *dentry) decRefLocked() {
- if refs := atomic.AddInt64(&d.refs, -1); refs < 0 {
- panic("gofer.dentry.decRefLocked() called without holding a reference")
+func (d *dentry) decRefNoCaching() int64 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r < 0 {
+ panic("gofer.dentry.decRefNoCaching() called without holding a reference")
}
+ return r
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "gofer.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -1223,6 +1324,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
// resolution, which requires renameMu, so if d.refs is zero then it will
// remain zero while we hold renameMu for writing.)
refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
if refs > 0 {
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1231,10 +1336,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
}
return
}
- if refs == -1 {
- // Dentry has already been destroyed.
- return
- }
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
@@ -1257,6 +1358,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
if d.watches.Size() > 0 {
return
}
+
+ if atomic.LoadInt32(&d.fs.released) != 0 {
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ delete(d.parent.children, d.name)
+ d.parent.dirMu.Unlock()
+ }
+ d.destroyLocked(ctx)
+ }
+
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1269,33 +1380,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
d.fs.cachedDentriesLen++
d.cached = true
if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
- victim := d.fs.cachedDentries.Back()
- d.fs.cachedDentries.Remove(victim)
- d.fs.cachedDentriesLen--
- victim.cached = false
- // victim.refs may have become non-zero from an earlier path resolution
- // since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 {
- if victim.parent != nil {
- victim.parent.dirMu.Lock()
- if !victim.vfsd.IsDead() {
- // Note that victim can't be a mount point (in any mount
- // namespace), since VFS holds references on mount points.
- d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
- delete(victim.parent.children, victim.name)
- // We're only deleting the dentry, not the file it
- // represents, so we don't need to update
- // victimParent.dirents etc.
- }
- victim.parent.dirMu.Unlock()
- }
- victim.destroyLocked(ctx)
- }
+ d.fs.evictCachedDentryLocked(ctx)
// Whether or not victim was destroyed, we brought fs.cachedDentriesLen
// back down to fs.opts.maxCachedDentries, so we don't loop.
}
}
+// Precondition: fs.renameMu must be locked for writing; it may be temporarily
+// unlocked.
+func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+}
+
+// Preconditions:
+// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// * fs.cachedDentriesLen != 0.
+func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ victim := fs.cachedDentries.Back()
+ fs.cachedDentries.Remove(victim)
+ fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+}
+
// destroyLocked destroys the dentry.
//
// Preconditions:
@@ -1373,13 +1499,10 @@ func (d *dentry) destroyLocked(ctx context.Context) {
// Drop the reference held by d on its parent without recursively locking
// d.fs.renameMu.
- if d.parent != nil {
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkCachingLocked(ctx)
- } else if refs < 0 {
- panic("gofer.dentry.DecRef() called without holding a reference")
- }
+ if d.parent != nil && d.parent.decRefNoCaching() == 0 {
+ d.parent.checkCachingLocked(ctx)
}
+ refsvfs2.Unregister(d)
}
func (d *dentry) isDeleted() bool {
@@ -1623,6 +1746,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
return nil
}
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ if h.isOpen() {
+ // Write back dirty pages to the remote file.
+ d.dataMu.Lock()
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (d is a regular file and was opened for writing).
+ if d.isRegularFile() && h.isOpen() {
+ return err
+ }
+ ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err)
+ }
+ return nil
+}
+
// incLinks increments link count.
func (d *dentry) incLinks() {
if atomic.LoadUint32(&d.nlink) == 0 {
@@ -1650,7 +1800,7 @@ type fileDescription struct {
vfs.FileDescriptionDefaultImpl
vfs.LockFD
- lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ lockLogging sync.Once `state:"nosave"`
}
func (fd *fileDescription) filesystem() *filesystem {
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index bfe75dfe4..76f08e252 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -26,12 +26,13 @@ import (
func TestDestroyIdempotent(t *testing.T) {
ctx := contexttest.Context(t)
fs := filesystem{
- mfp: pgalloc.MemoryFileProviderFromContext(ctx),
- syncableDentries: make(map[*dentry]struct{}),
+ mfp: pgalloc.MemoryFileProviderFromContext(ctx),
opts: filesystemOptions{
// Test relies on no dentry being held in the cache.
maxCachedDentries: 0,
},
+ syncableDentries: make(map[*dentry]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
index 7294de7d6..c7bf10007 100644
--- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
if ok {
return nil
}
- if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
- return err
+ if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil {
+ // Another application thread may have opened this pipe for
+ // writing, succeeded because we previously opened the pipe for
+ // reading, and subsequently interrupted us for checkpointing (e.g.
+ // this occurs in mknod tests under cooperative save/restore). In
+ // this case, our open has to succeed for the checkpoint to include
+ // a readable FD for the pipe, which is in turn necessary to
+ // restore the other thread's writable FD for the same pipe
+ // (otherwise it will get ENXIO). So we have to check
+ // nonblockingPipeHasWriter() once last time.
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ return sleepErr
}
}
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f8b19bae7..dc8a890cb 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -18,7 +18,6 @@ import (
"fmt"
"io"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx)
-}
-
-func (d *dentry) syncCachedFile(ctx context.Context) error {
- d.handleMu.RLock()
- defer d.handleMu.RUnlock()
-
- if h := d.writeHandleLocked(); h.isOpen() {
- d.dataMu.Lock()
- // Write dirty cached data to the remote file.
- err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
- d.dataMu.Unlock()
- if err != nil {
- return err
- }
- }
- return d.syncRemoteFileLocked(ctx)
+ return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
@@ -913,7 +897,7 @@ type dentryPlatformFile struct {
hostFileMapper fsutil.HostFileMapper
// hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
- hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ hostFileMapperInitOnce sync.Once `state:"nosave"`
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
new file mode 100644
index 000000000..17849dcc0
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -0,0 +1,329 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type saveRestoreContextID int
+
+const (
+ // CtxRestoreServerFDMap is a Context.Value key for a map[string]int
+ // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID)
+ // to host FDs.
+ CtxRestoreServerFDMap saveRestoreContextID = iota
+)
+
+// +stateify savable
+type savedDentryRW struct {
+ read bool
+ write bool
+}
+
+// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave.
+func (fs *filesystem) PrepareSave(ctx context.Context) error {
+ if len(fs.iopts.UniqueID) == 0 {
+ return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved")
+ }
+
+ // Purge cached dentries, which may not be reopenable after restore due to
+ // permission changes.
+ fs.renameMu.Lock()
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // Buffer pipe data so that it's available for reading after restore. (This
+ // is a legacy VFS1 feature.)
+ fs.syncMu.Lock()
+ for sffd := range fs.specialFileFDs {
+ if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() {
+ if err := sffd.savePipeData(ctx); err != nil {
+ fs.syncMu.Unlock()
+ return err
+ }
+ }
+ }
+ fs.syncMu.Unlock()
+
+ // Flush local state to the remote filesystem.
+ if err := fs.Sync(ctx); err != nil {
+ return err
+ }
+
+ fs.savedDentryRW = make(map[*dentry]savedDentryRW)
+ return fs.root.prepareSaveRecursive(ctx)
+}
+
+// Preconditions:
+// * fd represents a pipe.
+// * fd is readable.
+func (fd *specialFileFD) savePipeData(ctx context.Context) error {
+ fd.bufMu.Lock()
+ defer fd.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0))
+ if n != 0 {
+ fd.buf = append(fd.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syserror.EAGAIN {
+ break
+ }
+ return err
+ }
+ }
+ if len(fd.buf) != 0 {
+ atomic.StoreUint32(&fd.haveBuf, 1)
+ }
+ return nil
+}
+
+func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
+ if d.isRegularFile() && !d.cachedMetadataAuthoritative() {
+ // Get updated metadata for d in case we need to perform metadata
+ // validation during restore.
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
+ }
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ for _, child := range d.children {
+ if child != nil {
+ if err := child.prepareSaveRecursive(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// beforeSave is invoked by stateify.
+func (d *dentry) beforeSave() {
+ if d.vfsd.IsDead() {
+ panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d)))
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentry) afterLoad() {
+ d.hostFD = -1
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentryPlatformFile) afterLoad() {
+ if d.hostFileMapper.IsInited() {
+ // Ensure that we don't call d.hostFileMapper.Init() again.
+ d.hostFileMapperInitOnce.Do(func() {})
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (fd *specialFileFD) afterLoad() {
+ fd.handle.fd = -1
+}
+
+// CompleteRestore implements
+// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore.
+func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error {
+ fdmapv := ctx.Value(CtxRestoreServerFDMap)
+ if fdmapv == nil {
+ return fmt.Errorf("no server FD map available")
+ }
+ fdmap := fdmapv.(map[string]int)
+ fd, ok := fdmap[fs.iopts.UniqueID]
+ if !ok {
+ return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
+ }
+ fs.opts.fd = fd
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+ fs.inoByQIDPath = make(map[uint64]uint64)
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
+
+ // Restore remaining dentries.
+ if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil {
+ return err
+ }
+
+ // Re-open handles for specialFileFDs. Unlike the initial open
+ // (dentry.openSpecialFile()), pipes are always opened without blocking;
+ // non-readable pipe FDs are opened last to ensure that they don't get
+ // ENXIO if another specialFileFD represents the read end of the same pipe.
+ // This is consistent with VFS1.
+ haveWriteOnlyPipes := false
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ haveWriteOnlyPipes = true
+ continue
+ }
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ if haveWriteOnlyPipes {
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ // Discard state only required during restore.
+ fs.savedDentryRW = nil
+
+ return nil
+}
+
+func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error {
+ d.file = file
+
+ // Gofers do not preserve QID across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking QIDs.
+ //
+ // - We need to associate the new QID.Path with the existing d.ino.
+ d.qidPath = qid.Path
+ d.fs.inoMu.Lock()
+ d.fs.inoByQIDPath[qid.Path] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if !attrMask.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ }
+ if d.size != attr.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if !attrMask.MTime {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ }
+ if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromP9AttrsLocked(attrMask, attr)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Preconditions: d is not synthetic.
+func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ for _, child := range d.children {
+ if child == nil {
+ continue
+ }
+ if _, ok := d.fs.syncableDentries[child]; !ok {
+ // child is synthetic.
+ continue
+ }
+ if err := child.restoreRecursive(ctx, opts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Preconditions: d is not synthetic (but note that since this function
+// restores d.file, d.file.isNil() is always true at this point, so this can
+// only be detected by checking filesystem.syncableDentries). d.parent has been
+// restored.
+func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
+ return d.restoreDescendantsRecursive(ctx, opts)
+}
+
+func (fd *specialFileFD) completeRestore(ctx context.Context) error {
+ d := fd.dentry()
+ h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ if err != nil {
+ return err
+ }
+ fd.handle = h
+
+ ftype := d.fileType()
+ fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0
+ if fd.haveQueue {
+ if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index 326b940a7..a21199eac 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -42,9 +42,6 @@ type endpoint struct {
// dentry is the filesystem dentry which produced this endpoint.
dentry *dentry
- // file is the p9 file that contains a single unopened fid.
- file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
-
// path is the sentry path where this endpoint is bound.
path string
}
@@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
}
func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.file.Connect(flags)
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
return nil, syserr.ErrConnectionRefused
}
@@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla
c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
return nil, serr
}
return c, nil
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 71581736c..625400c0b 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -15,7 +15,6 @@
package gofer
import (
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -40,7 +40,7 @@ type specialFileFD struct {
fileDescription
// handle is used for file I/O. handle is immutable.
- handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ handle handle `state:"nosave"`
// isRegularFile is true if this FD represents a regular file which is only
// possible when filesystemOptions.regularFilesUseSpecialFileFD is in
@@ -54,12 +54,20 @@ type specialFileFD struct {
// haveQueue is true if this file description represents a file for which
// queue may send I/O readiness events. haveQueue is immutable.
- haveQueue bool
+ haveQueue bool `state:"nosave"`
queue waiter.Queue
// If seekable is true, off is the file offset. off is protected by mu.
mu sync.Mutex `state:"nosave"`
off int64
+
+ // If haveBuf is non-zero, this FD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to specialFileFD.savePipeData().
+ // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic
+ // memory operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
}
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
@@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
}
return nil, err
}
+ d.fs.syncMu.Lock()
+ d.fs.specialFileFDs[fd] = struct{}{}
+ d.fs.syncMu.Unlock()
return fd, nil
}
@@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
return 0, syserror.EOPNOTSUPP
}
- // Going through dst.CopyOutFrom() holds MM locks around file operations of
- // unknown duration. For regularFileFD, doing so is necessary to support
- // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
- // hold here since specialFileFD doesn't client-cache data. Just buffer the
- // read instead.
if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
+
+ bufN := int64(0)
+ if atomic.LoadUint32(&fd.haveBuf) != 0 {
+ var err error
+ fd.bufMu.Lock()
+ if len(fd.buf) != 0 {
+ var n int
+ n, err = dst.CopyOut(ctx, fd.buf)
+ dst = dst.DropFirst(n)
+ fd.buf = fd.buf[n:]
+ if len(fd.buf) == 0 {
+ atomic.StoreUint32(&fd.haveBuf, 0)
+ fd.buf = nil
+ }
+ bufN = int64(n)
+ if offset >= 0 {
+ offset += bufN
+ }
+ }
+ fd.bufMu.Unlock()
+ if err != nil {
+ return bufN, err
+ }
+ }
+
+ // Going through dst.CopyOutFrom() would hold MM locks around file
+ // operations of unknown duration. For regularFileFD, doing so is necessary
+ // to support mmap due to lock ordering; MM locks precede dentry.dataMu.
+ // That doesn't hold here since specialFileFD doesn't client-cache data.
+ // Just buffer the read instead.
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
if n == 0 {
- return 0, err
+ return bufN, err
}
if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
- return int64(cp), cperr
+ return bufN + int64(cp), cperr
}
- return int64(n), err
+ return bufN + int64(n), err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
d := fd.dentry()
- // If the regular file fd was opened with O_APPEND, make sure the file size
- // is updated. There is a possible race here if size is modified externally
- // after metadata cache is updated.
- if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
- if err := d.updateFromGetattr(ctx); err != nil {
- return 0, offset, err
+ if fd.isRegularFile {
+ // If the regular file fd was opened with O_APPEND, make sure the file
+ // size is updated. There is a possible race here if size is modified
+ // externally after metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
- }
- if fd.isRegularFile {
// We need to hold the metadataMu *while* writing to a regular file.
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
@@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- // If we have a host FD, fsyncing it is likely to be faster than an fsync
- // RPC.
- if fd.handle.fd >= 0 {
- ctx.UninterruptibleSleepStart(false)
- err := syscall.Fsync(int(fd.handle.fd))
- ctx.UninterruptibleSleepFinish(false)
- return err
+ return fd.sync(ctx, false /* forFilesystemSync */)
+}
+
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+ err := func() error {
+ // If we have a host FD, fsyncing it is likely to be faster than an fsync
+ // RPC.
+ if fd.handle.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(fd.handle.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ return fd.handle.file.fsync(ctx)
+ }()
+ if err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (fd represents a regular file that was opened for writing).
+ if fd.isRegularFile && fd.vfsfd.IsWritable() {
+ return err
+ }
+ ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err)
}
- return fd.handle.file.fsync(ctx)
+ return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 7e825caae..9cbe805b9 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,7 +17,6 @@ package gofer
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
-func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
- return ts.Sec*1e9 + int64(ts.Nsec)
-}
-
-func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
- return linux.StatxTimestamp{
- Sec: ns / 1e9,
- Nsec: uint32(ns % 1e9),
- }
-}
-
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 56bcf9bdb..4ae9d6d5e 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "inode_refs.go",
package = "host",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "connected_endpoint_refs.go",
package = "host",
prefix = "ConnectedEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ConnectedEndpoint",
},
@@ -33,7 +33,7 @@ go_library(
"host.go",
"inode_refs.go",
"ioctl_unsafe.go",
- "mmap.go",
+ "save_restore.go",
"socket.go",
"socket_iovec.go",
"socket_unsafe.go",
@@ -51,6 +51,7 @@ go_library(
"//pkg/log",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
index 0135e4428..13ef48cb5 100644
--- a/pkg/sentry/fsimpl/host/control.go
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription {
}
// Create the file backed by hostFD.
- file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */)
+ file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{})
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 698e913fe..39b902a3e 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -19,6 +19,7 @@ package host
import (
"fmt"
"math"
+ "sync/atomic"
"syscall"
"golang.org/x/sys/unix"
@@ -40,34 +41,97 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) {
- // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
- // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
- // devices.
+// inode implements kernfs.Inode.
+//
+// +stateify savable
+type inode struct {
+ kernfs.InodeNoStatFS
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.CachedMappable
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
+
+ locks vfs.FileLocks
+
+ // When the reference count reaches zero, the host fd is closed.
+ inodeRefs
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // ftype is the file's type (a linux.S_IFMT mask).
+ //
+ // This field is initialized at creation time and is immutable.
+ ftype uint16
+
+ // mayBlock is true if hostFD is non-blocking, and operations on it may
+ // return EAGAIN or EWOULDBLOCK instead of blocking.
+ //
+ // This field is initialized at creation time and is immutable.
+ mayBlock bool
+
+ // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file
+ // offsets are meaningful iff seekable is true.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // savable is true if hostFD may be saved/restored by its numeric value.
+ //
+ // This field is initialized at creation time and is immutable.
+ savable bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to inode.beforeSave(). haveBuf
+ // and buf are protected by bufMu. haveBuf is accessed using atomic memory
+ // operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
+}
+
+func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) {
+ // Determine if hostFD is seekable.
_, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
seekable := err != syserror.ESPIPE
+ // We expect regular files to be seekable, as this is required for them to
+ // be memory-mappable.
+ if !seekable && fileType == syscall.S_IFREG {
+ ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD)
+ return nil, syserror.ESPIPE
+ }
i := &inode{
- hostFD: hostFD,
- ino: fs.NextIno(),
- isTTY: isTTY,
- wouldBlock: wouldBlock(uint32(fileType)),
- seekable: seekable,
- // NOTE(b/38213152): Technically, some obscure char devices can be memory
- // mapped, but we only allow regular files.
- canMap: fileType == linux.S_IFREG,
- }
- i.pf.inode = i
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ ftype: uint16(fileType),
+ mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR,
+ seekable: seekable,
+ isTTY: isTTY,
+ savable: savable,
+ }
+ i.CachedMappable.Init(hostFD)
i.EnableLeakCheck()
- // Non-seekable files can't be memory mapped, assert this.
- if !i.seekable && i.canMap {
- panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
- }
-
- // If the hostFD would block, we must set it to non-blocking and handle
- // blocking behavior in the sentry.
- if i.wouldBlock {
+ // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and
+ // handle blocking behavior in the sentry.
+ if i.mayBlock {
if err := syscall.SetNonblock(i.hostFD, true); err != nil {
return nil, err
}
@@ -80,6 +144,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (
// NewFDOptions contains options to NewFD.
type NewFDOptions struct {
+ // If Savable is true, the host file descriptor may be saved/restored by
+ // numeric value; the sandbox API requires a corresponding host FD with the
+ // same numeric value to be provieded at time of restore.
+ Savable bool
+
// If IsTTY is true, the file descriptor is a TTY.
IsTTY bool
@@ -114,7 +183,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
}
d := &kernfs.Dentry{}
- i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
+ i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
if err != nil {
return nil, err
}
@@ -132,7 +201,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
return NewFD(ctx, mnt, hostFD, &NewFDOptions{
- IsTTY: isTTY,
+ Savable: true,
+ IsTTY: isTTY,
})
}
@@ -191,68 +261,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
-// inode implements kernfs.Inode.
-//
-// +stateify savable
-type inode struct {
- kernfs.InodeNoStatFS
- kernfs.InodeNotDirectory
- kernfs.InodeNotSymlink
- kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
-
- locks vfs.FileLocks
-
- // When the reference count reaches zero, the host fd is closed.
- inodeRefs
-
- // hostFD contains the host fd that this file was originally created from,
- // which must be available at time of restore.
- //
- // This field is initialized at creation time and is immutable.
- hostFD int
-
- // ino is an inode number unique within this filesystem.
- //
- // This field is initialized at creation time and is immutable.
- ino uint64
-
- // isTTY is true if this file represents a TTY.
- //
- // This field is initialized at creation time and is immutable.
- isTTY bool
-
- // seekable is false if the host fd points to a file representing a stream,
- // e.g. a socket or a pipe. Such files are not seekable and can return
- // EWOULDBLOCK for I/O operations.
- //
- // This field is initialized at creation time and is immutable.
- seekable bool
-
- // wouldBlock is true if the host FD would return EWOULDBLOCK for
- // operations that would block.
- //
- // This field is initialized at creation time and is immutable.
- wouldBlock bool
-
- // Event queue for blocking operations.
- queue waiter.Queue
-
- // canMap specifies whether we allow the file to be memory mapped.
- //
- // This field is initialized at creation time and is immutable.
- canMap bool
-
- // mapsMu protects mappings.
- mapsMu sync.Mutex `state:"nosave"`
-
- // If canMap is true, mappings tracks mappings of hostFD into
- // memmap.MappingSpaces.
- mappings memmap.MappingSet
-
- // pf implements platform.File for mappings of hostFD.
- pf inodePlatformFile
-}
-
// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s syscall.Stat_t
@@ -422,14 +430,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
oldpgend, _ := usermem.PageRoundUp(oldSize)
newpgend, _ := usermem.PageRoundUp(s.Size)
if oldpgend != newpgend {
- i.mapsMu.Lock()
- i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
- // Compare Linux's mm/truncate.c:truncate_setsize() =>
- // truncate_pagecache() =>
- // mm/memory.c:unmap_mapping_range(evencows=1).
- InvalidatePrivate: true,
- })
- i.mapsMu.Unlock()
+ i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend})
}
}
}
@@ -448,7 +449,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
// DecRef implements kernfs.Inode.DecRef.
func (i *inode) DecRef(ctx context.Context) {
i.inodeRefs.DecRef(func() {
- if i.wouldBlock {
+ if i.mayBlock {
fdnotifier.RemoveFD(int32(i.hostFD))
}
if err := unix.Close(i.hostFD); err != nil {
@@ -567,6 +568,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin
// PRead implements vfs.FileDescriptionImpl.PRead.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
return 0, syserror.ESPIPE
@@ -577,19 +585,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off
// Read implements vfs.FileDescriptionImpl.Read.
func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
+ bufN, err := i.readFromBuf(ctx, &dst)
+ if err != nil {
+ return bufN, err
+ }
n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ total := bufN + n
if isBlockError(err) {
// If we got any data at all, return it as a "completed" partial read
// rather than retrying until complete.
- if n != 0 {
+ if total != 0 {
err = nil
} else {
err = syserror.ErrWouldBlock
}
}
- return n, err
+ return total, err
}
f.offsetMu.Lock()
@@ -599,13 +619,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
return n, err
}
-func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
- // Check that flags are supported.
- //
- // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
- if flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) {
+ if atomic.LoadUint32(&i.haveBuf) == 0 {
+ return 0, nil
+ }
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ if len(i.buf) == 0 {
+ return 0, nil
}
+ n, err := dst.CopyOut(ctx, i.buf)
+ *dst = dst.DropFirst(n)
+ i.buf = i.buf[n:]
+ if len(i.buf) == 0 {
+ atomic.StoreUint32(&i.haveBuf, 0)
+ i.buf = nil
+ }
+ return int64(n), err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := dst.CopyOutFrom(ctx, reader)
hostfd.PutReadWriterAt(reader)
@@ -735,31 +768,37 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
}
// Sync implements vfs.FileDescriptionImpl.Sync.
-func (f *fileDescription) Sync(context.Context) error {
+func (f *fileDescription) Sync(ctx context.Context) error {
// TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
return unix.Fsync(f.inode.hostFD)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
- if !f.inode.canMap {
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ if f.inode.ftype != syscall.S_IFREG {
return syserror.ENODEV
}
i := f.inode
- i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+ i.CachedMappable.InitFileMapperOnce()
return vfs.GenericConfigureMMap(&f.vfsfd, i, opts)
}
// EventRegister implements waiter.Waitable.EventRegister.
func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
f.inode.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (f *fileDescription) EventUnregister(e *waiter.Entry) {
f.inode.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// Readiness uses the poll() syscall to check the status of the underlying FD.
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
new file mode 100644
index 000000000..8800652a9
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// beforeSave is invoked by stateify.
+func (i *inode) beforeSave() {
+ if !i.savable {
+ panic("host.inode is not savable")
+ }
+ if i.ftype == syscall.S_IFIFO {
+ // If this pipe FD is readable, drain it so that bytes in the pipe can
+ // be read after restore. (This is a legacy VFS1 feature.) We don't
+ // know if the pipe FD is readable, so just try reading and tolerate
+ // EBADF from the read.
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */)
+ if n != 0 {
+ i.buf = append(i.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF {
+ break
+ }
+ panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err))
+ }
+ }
+ if len(i.buf) != 0 {
+ atomic.StoreUint32(&i.haveBuf, 1)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inode) afterLoad() {
+ if i.mayBlock {
+ if err := syscall.SetNonblock(i.hostFD, true); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err))
+ }
+ if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err))
+ }
+ }
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 412bdb2eb..b2f43a119 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
}
-// wouldBlock returns true for file types that can return EWOULDBLOCK
-// for blocking operations, e.g. pipes, character devices, and sockets.
-func wouldBlock(fileType uint32) bool {
- return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
-}
-
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
// If so, they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 858cc24ce..6dbc7e34d 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "kernfs",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Dentry",
+ "Linker": "*Dentry",
+ },
+)
+
+go_template_instance(
name = "fstree",
out = "fstree.go",
package = "kernfs",
@@ -27,22 +39,11 @@ go_template_instance(
)
go_template_instance(
- name = "dentry_refs",
- out = "dentry_refs.go",
- package = "kernfs",
- prefix = "Dentry",
- template = "//pkg/refs_vfs2:refs_template",
- types = {
- "T": "Dentry",
- },
-)
-
-go_template_instance(
name = "static_directory_refs",
out = "static_directory_refs.go",
package = "kernfs",
prefix = "StaticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "StaticDirectory",
},
@@ -53,7 +54,7 @@ go_template_instance(
out = "dir_refs.go",
package = "kernfs_test",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -64,7 +65,7 @@ go_template_instance(
out = "readonly_dir_refs.go",
package = "kernfs_test",
prefix = "readonlyDir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "readonlyDir",
},
@@ -75,7 +76,7 @@ go_template_instance(
out = "synthetic_directory_refs.go",
package = "kernfs",
prefix = "syntheticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "syntheticDirectory",
},
@@ -84,13 +85,15 @@ go_template_instance(
go_library(
name = "kernfs",
srcs = [
- "dentry_refs.go",
+ "dentry_list.go",
"dynamic_bytes_file.go",
"fd_impl_util.go",
"filesystem.go",
"fstree.go",
"inode_impl_util.go",
"kernfs.go",
+ "mmap_util.go",
+ "save_restore.go",
"slot_list.go",
"static_directory_refs.go",
"symlink.go",
@@ -104,8 +107,12 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
@@ -129,6 +136,7 @@ go_test(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index b929118b1..485504995 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -47,11 +47,11 @@ type DynamicBytesFile struct {
var _ Inode = (*DynamicBytesFile)(nil)
// Init initializes a dynamic bytes file.
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
f.data = data
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index abf1905d6..f8dae22f8 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
return fd.vfsfd.VirtualDentry().Mount().Filesystem()
}
+func (fd *GenericDirectoryFD) dentry() *Dentry {
+ return fd.vfsfd.Dentry().Impl().(*Dentry)
+}
+
func (fd *GenericDirectoryFD) inode() Inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return fd.dentry().inode
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
@@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// Handle "..".
if fd.off == 1 {
- vfsd := fd.vfsfd.VirtualDentry().Dentry()
- parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
+ parentInode := genericParentOrSelf(fd.dentry()).inode
stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
@@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
var err error
relOffset := fd.off - int64(len(fd.children.set)) - 2
- fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset)
return err
}
@@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
- inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(ctx, fd.filesystem(), creds, opts)
+ return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts)
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 6426a55f6..e77523f22 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
// Preconditions:
// * Filesystem.mu must be locked for at least reading.
// * isDir(parentInode) == true.
-func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) {
- if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
- return "", err
+func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error {
+ if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
}
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return "", syserror.EEXIST
+ if name == "." || name == ".." {
+ return syserror.EEXIST
}
- if len(pc) > linux.NAME_MAX {
- return "", syserror.ENAMETOOLONG
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
}
- if _, ok := parent.children[pc]; ok {
- return "", syserror.EEXIST
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
}
if parent.VFSDentry().IsDead() {
- return "", syserror.ENOENT
+ return syserror.ENOENT
}
- return pc, nil
+ return nil
}
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
@@ -245,7 +244,41 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *Filesystem) Release(context.Context) {
+func (fs *Filesystem) Release(ctx context.Context) {
+ root := fs.root
+ if root == nil {
+ return
+ }
+ fs.mu.Lock()
+ root.releaseKeptDentriesLocked(ctx)
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+ fs.mu.Unlock()
+ // Drop ref acquired in Dentry.InitRoot().
+ root.DecRef(ctx)
+}
+
+// releaseKeptDentriesLocked recursively drops all dentry references created by
+// Lookup when Dentry.inode.Keep() is true.
+//
+// Precondition: Filesystem.mu is held.
+func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) {
+ if d.inode.Keep() && d != d.fs.root {
+ d.decRefLocked(ctx)
+ }
+
+ if d.isDir() {
+ var children []*Dentry
+ d.dirMu.Lock()
+ for _, child := range d.children {
+ children = append(children, child)
+ }
+ d.dirMu.Unlock()
+ for _, child := range children {
+ child.releaseKeptDentriesLocked(ctx)
+ }
+ }
}
// Sync implements vfs.FilesystemImpl.Sync.
@@ -318,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if rp.Mount() != vd.Mount() {
return syserror.EXDEV
}
@@ -360,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -373,7 +409,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
- childI = newSyntheticDirectory(rp.Credentials(), opts.Mode)
+ childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode)
}
var child Dentry
child.Init(fs, childI)
@@ -396,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
@@ -517,9 +556,6 @@ afterTrailingSymlink:
}
var child Dentry
child.Init(fs, childI)
- // FIXME(gvisor.dev/issue/1193): Race between checking existence with
- // fs.stepExistingLocked and parent.insertChild. If possible, we should hold
- // dirMu from one to the other.
parent.insertChild(pc, &child)
// Open may block so we need to unlock fs.mu. IncRef child to prevent
// its destruction while fs.mu is unlocked.
@@ -626,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// Can we create the dst dentry?
var dst *Dentry
- pc, err := checkCreateLocked(ctx, rp, dstDir)
- switch err {
+ pc := rp.Component()
+ switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err {
case nil:
// Ok, continue with rename as replacement.
case syserror.EEXIST:
@@ -791,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
parent.dirMu.Lock()
defer parent.dirMu.Unlock()
- pc, err := checkCreateLocked(ctx, rp, parent)
- if err != nil {
+ pc := rp.Component()
+ if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil {
return err
}
+ if rp.MustBeDir() {
+ return syserror.ENOENT
+ }
if err := rp.Mount().CheckBeginWrite(); err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 122b10591..d83c17f83 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -21,9 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// InodeNoopRefCount partially implements the Inode interface, specifically the
@@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error)
}
// IterDirents implements Inode.IterDirents.
-func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
panic("IterDirents called on non-directory inode")
}
@@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry,
//
// +stateify savable
type InodeAttrs struct {
- devMajor uint32
- devMinor uint32
- ino uint64
- mode uint32
- uid uint32
- gid uint32
- nlink uint32
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+ blockSize uint32
+
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
}
// Init initializes this InodeAttrs.
-func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
+func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
if mode.FileType() == 0 {
panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
}
@@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in
atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
atomic.StoreUint32(&a.nlink, nlink)
+ atomic.StoreUint32(&a.blockSize, usermem.PageSize)
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.atime, now)
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
}
// DevMajor returns the device major number.
@@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode {
return linux.FileMode(atomic.LoadUint32(&a.mode))
}
+// TouchAtime updates a.atime to the current time.
+func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) {
+ if mnt.Flags.NoATime || mnt.ReadOnly() {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds())
+ mnt.EndWrite()
+}
+
+// TouchCMtime updates a.{c/m}time to the current time. The caller should
+// synchronize calls to this so that ctime and mtime are updated to the same
+// value.
+func (a *InodeAttrs) TouchCMtime(ctx context.Context) {
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
+}
+
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME
stat.DevMajor = a.devMajor
stat.DevMinor = a.devMinor
stat.Ino = atomic.LoadUint64(&a.ino)
@@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li
stat.UID = atomic.LoadUint32(&a.uid)
stat.GID = atomic.LoadUint32(&a.gid)
stat.Nlink = atomic.LoadUint32(&a.nlink)
-
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
-
+ stat.Blksize = atomic.LoadUint32(&a.blockSize)
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime))
return stat, nil
}
// SetStat implements Inode.SetStat.
func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- return a.SetInodeStat(ctx, fs, creds, opts)
-}
-
-// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
-// This function can be used by other kernfs-based filesystem implementation to
-// sets the unexported attributes into InodeAttrs.
-func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
}
@@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
// inode numbers are immutable after node creation. Setting the size is often
// allowed by kernfs files but does not do anything. If some other behavior is
// needed, the embedder should consider extending SetStat.
- //
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
@@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
atomic.StoreUint32(&a.gid, stat.GID)
}
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ stat.Atime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.atime, stat.Atime.ToNsec())
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec())
+ }
+
return nil
}
@@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error
}
// IterDirents implements Inode.IterDirents.
-func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
// All entries from OrderedChildren have already been handled in
// GenericDirectoryFD.IterDirents.
return offset, nil
@@ -528,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e
return o.Unlink(ctx, name, child)
}
-// +stateify savable
-type renameAcrossDifferentImplementationsError struct{}
-
-func (renameAcrossDifferentImplementationsError) Error() string {
- return "rename across inodes with different implementations"
-}
-
// Rename implements Inode.Rename.
//
// Precondition: Rename may only be called across two directory inodes with
@@ -545,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string {
//
// Postcondition: reference on any replaced dentry transferred to caller.
func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+
dst, ok := dstDir.(interface{}).(*OrderedChildren)
if !ok {
- return renameAcrossDifferentImplementationsError{}
+ return syserror.EXDEV
}
- if !o.writable || !dst.writable {
+ if !dst.writable {
return syserror.EPERM
}
+
// Note: There's a potential deadlock below if concurrent calls to Rename
// refer to the same src and dst directories in reverse. We avoid any
// ordering issues because the caller is required to serialize concurrent
@@ -619,9 +657,9 @@ type StaticDirectory struct {
var _ Inode = (*StaticDirectory)(nil)
// NewStaticDir creates a new static directory and returns its dentry.
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
+func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
- inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(OrderedChildrenOptions{})
@@ -632,12 +670,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64
}
// Init initializes StaticDirectory.
-func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
+func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
s.fdOpts = fdOpts
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
// Open implements Inode.Open.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 606081e68..abb477c7d 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
@@ -107,6 +108,23 @@ type Filesystem struct {
// nextInoMinusOne is used to to allocate inode numbers on this
// filesystem. Must be accessed by atomic operations.
nextInoMinusOne uint64
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by mu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // MaxCachedDentries is the maximum size of cachedDentries. If not set,
+ // defaults to 0 and kernfs does not cache any dentries. This is immutable.
+ MaxCachedDentries uint64
+
+ // root is the root dentry of this filesystem. Note that root may be nil for
+ // filesystems on a disconnected mount without a root (e.g. pipefs, sockfs,
+ // hostfs). Filesystem holds an extra reference on root to prevent it from
+ // being destroyed prematurely. This is immutable.
+ root *Dentry
}
// deferDecRef defers dropping a dentry ref until the next call to
@@ -165,7 +183,12 @@ const (
// +stateify savable
type Dentry struct {
vfsd vfs.Dentry
- DentryRefs
+
+ // refs is the reference count. When refs reaches 0, the dentry may be
+ // added to the cache or destroyed. If refs == -1, the dentry has already
+ // been destroyed. refs are allowed to go to 0 and increase again. refs is
+ // accessed using atomic memory operations.
+ refs int64
// fs is the owning filesystem. fs is immutable.
fs *Filesystem
@@ -177,6 +200,12 @@ type Dentry struct {
parent *Dentry
name string
+ // If cached is true, dentryEntry links dentry into
+ // Filesystem.cachedDentries. cached and dentryEntry are protected by
+ // Filesystem.mu.
+ cached bool
+ dentryEntry
+
// dirMu protects children and the names of child Dentries.
//
// Note that holding fs.mu for writing is not sufficient;
@@ -188,6 +217,201 @@ type Dentry struct {
inode Inode
}
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *Dentry) IncRef() {
+ // d.refs may be 0 if d.fs.mu is locked, which serializes against
+ // d.cacheLocked().
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *Dentry) TryIncRef() bool {
+ for {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.fs.mu.Lock()
+ d.cacheLocked(ctx)
+ d.fs.mu.Unlock()
+ } else if r < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+}
+
+func (d *Dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.cacheLocked(ctx)
+ } else if r < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+}
+
+// cacheLocked should be called after d's reference count becomes 0. The ref
+// count check may happen before acquiring d.fs.mu so there might be a race
+// condition where the ref count is increased again by the time the caller
+// acquires d.fs.mu. This race is handled.
+// Only reachable dentries are added to the cache. However, a dentry might
+// become unreachable *while* it is in the cache due to invalidation.
+//
+// Preconditions: d.fs.mu must be locked for writing.
+func (d *Dentry) cacheLocked(ctx context.Context) {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires d.fs.mu, so if d.refs is zero then it will
+ // remain zero while we hold d.fs.mu for writing.)
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d))
+ }
+ if refs > 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ // If the dentry is deleted and invalidated or has no parent, then it is no
+ // longer reachable by path resolution and should be dropped immediately
+ // because it has zero references.
+ // Note that a dentry may not always have a parent; for example magic links
+ // as described in Inode.Getlink.
+ if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil {
+ if !isDead {
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
+ }
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked(ctx)
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries {
+ return
+ }
+ d.fs.evictCachedDentryLocked(ctx)
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.opts.maxCachedDentries, so we don't loop.
+}
+
+// Preconditions:
+// * fs.mu must be locked for writing.
+// * fs.cachedDentriesLen != 0.
+func (fs *Filesystem) evictCachedDentryLocked(ctx context.Context) {
+ // Evict the least recently used dentry because cache size is greater than
+ // max cache size (configured on mount).
+ victim := fs.cachedDentries.Back()
+ fs.cachedDentries.Remove(victim)
+ fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // after it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if !victim.vfsd.IsDead() {
+ victim.parent.dirMu.Lock()
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry())
+ delete(victim.parent.children, victim.name)
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.MaxCachedDentries, so we don't loop.
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions:
+// * d.fs.mu must be locked for writing.
+// * d.refs == 0.
+// * d should have been removed from d.parent.children, i.e. d is not reachable
+// by path traversal.
+// * d.vfsd.IsDead() is true.
+func (d *Dentry) destroyLocked(ctx context.Context) {
+ refs := atomic.LoadInt64(&d.refs)
+ switch refs {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
+ d.inode.DecRef(ctx) // IncRef from Init.
+ d.inode = nil
+
+ if d.parent != nil {
+ d.parent.decRefLocked(ctx)
+ }
+
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *Dentry) RefType() string {
+ return "kernfs.Dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *Dentry) LeakMessage() string {
+ return fmt.Sprintf("[kernfs.Dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *Dentry) LogRefs() bool {
+ return false
+}
+
+// InitRoot initializes this dentry as the root of the filesystem.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) InitRoot(fs *Filesystem, inode Inode) {
+ d.Init(fs, inode)
+ fs.root = d
+ // Hold an extra reference on the root dentry. It is held by fs to prevent the
+ // root from being "cached" and subsequently evicted.
+ d.IncRef()
+}
+
// Init initializes this dentry.
//
// Precondition: Caller must hold a reference on inode.
@@ -197,6 +421,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
d.vfsd.Init(d)
d.fs = fs
d.inode = inode
+ atomic.StoreInt64(&d.refs, 1)
ftype := inode.Mode().FileType()
if ftype == linux.ModeDirectory {
d.flags |= dflagsIsDir
@@ -204,7 +429,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
if ftype == linux.ModeSymlink {
d.flags |= dflagsIsSymlink
}
- d.EnableLeakCheck()
+ refsvfs2.Register(d)
}
// VFSDentry returns the generic vfs dentry for this kernfs dentry.
@@ -222,32 +447,6 @@ func (d *Dentry) isSymlink() bool {
return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
}
-// DecRef implements vfs.DentryImpl.DecRef.
-func (d *Dentry) DecRef(ctx context.Context) {
- decRefParent := false
- d.fs.mu.Lock()
- d.DentryRefs.DecRef(func() {
- d.inode.DecRef(ctx) // IncRef from Init.
- d.inode = nil
- if d.parent != nil {
- // We will DecRef d.parent once all locks are dropped.
- decRefParent = true
- d.parent.dirMu.Lock()
- // Remove d from parent.children. It might already have been
- // removed due to invalidation.
- if _, ok := d.parent.children[d.name]; ok {
- delete(d.parent.children, d.name)
- d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
- }
- d.parent.dirMu.Unlock()
- }
- })
- d.fs.mu.Unlock()
- if decRefParent {
- d.parent.DecRef(ctx) // IncRef from Dentry.insertChild.
- }
-}
-
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
//
// Although Linux technically supports inotify on pseudo filesystems (inotify
@@ -267,7 +466,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {}
// this dentry. This does not update the directory inode, so calling this on its
// own isn't sufficient to insert a child into a directory.
//
-// Precondition: d must represent a directory inode.
+// Preconditions:
+// * d must represent a directory inode.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChild(name string, child *Dentry) {
d.dirMu.Lock()
d.insertChildLocked(name, child)
@@ -280,6 +481,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) {
// Preconditions:
// * d must represent a directory inode.
// * d.dirMu must be locked.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChildLocked(name string, child *Dentry) {
if !d.isDir() {
panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d))
@@ -436,7 +638,7 @@ type inodeDirectory interface {
// the inode is a directory.
//
// The child returned by Lookup will be hashed into the VFS dentry tree,
- // atleast for the duration of the current FS operation.
+ // at least for the duration of the current FS operation.
//
// Lookup must return the child with an extra reference whose ownership is
// transferred to the dentry that is created to point to that inode. If
@@ -454,7 +656,7 @@ type inodeDirectory interface {
// inside the entries returned by this IterDirents invocation. In other words,
// 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
// the return value, while 'relOffset' is the place to start iteration.
- IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
+ IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
type inodeSymlink interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 82fa19c03..2418eec44 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file."
// RootDentryFn is a generator function for creating the root dentry of a test
// filesystem. See newTestSystem.
-type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode
+type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode
// newTestSystem sets up a minimal environment for running a test, including an
// instance of a test filesystem. Tests can control the contents of the
@@ -72,10 +72,10 @@ type file struct {
content string
}
-func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode {
+func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode {
f := &file{}
f.content = content
- f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
+ f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
return f
}
@@ -105,9 +105,9 @@ type readonlyDir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &readonlyDir{}
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
dir.EnableLeakCheck()
dir.IncLinks(dir.OrderedChildren.Populate(contents))
@@ -142,10 +142,10 @@ type dir struct {
fs *filesystem
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &dir{}
dir.fs = fs
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
dir.EnableLeakCheck()
@@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) {
func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- dir := d.fs.newDir(creds, opts.Mode, nil)
+ dir := d.fs.newDir(ctx, creds, opts.Mode, nil)
if err := d.OrderedChildren.Insert(name, dir); err != nil {
dir.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
d.IncLinks(1)
return dir, nil
}
func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- f := d.fs.newFile(creds, "")
+ f := d.fs.newFile(ctx, creds, "")
if err := d.OrderedChildren.Insert(name, f); err != nil {
f.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
return f, nil
}
@@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {}
func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
fs.VFSFilesystem().Init(vfsObj, &fst, fs)
- root := fst.rootFn(creds, fs)
+ root := fst.rootFn(ctx, creds, fs)
var d kernfs.Dentry
d.Init(&fs.Filesystem, root)
return fs.VFSFilesystem(), d.VFSDentry(), nil
@@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst
// -------------------- Remainder of the file are test cases --------------------
func TestBasic(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -228,9 +230,9 @@ func TestBasic(t *testing.T) {
}
func TestMkdirGetDentry(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) {
}
func TestReadStaticFile(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) {
}
func TestCreateNewFileInStaticDir(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
func TestDirFDReadWrite(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, nil)
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, nil)
})
defer sys.Destroy()
@@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) {
}
func TestDirFDIterDirents(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
// Fill root with nodes backed by various inode implementations.
- "dir1": fs.newReadonlyDir(creds, 0755, nil),
- "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{
- "dir3": fs.newDir(creds, 0755, nil),
+ "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil),
+ "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir3": fs.newDir(ctx, creds, 0755, nil),
}),
- "file1": fs.newFile(creds, staticFileContent),
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go
index b51a17bed..bd6a134b4 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package host
+package kernfs
import (
"gvisor.dev/gvisor/pkg/context"
@@ -26,11 +26,14 @@ import (
// inodePlatformFile implements memmap.File. It exists solely because inode
// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
-// inodePlatformFile should only be used if inode.canMap is true.
-//
// +stateify savable
type inodePlatformFile struct {
- *inode
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ // inodePlatformFile does not own hostFD and hence should not close it.
+ hostFD int
// fdRefsMu protects fdRefs.
fdRefsMu sync.Mutex `state:"nosave"`
@@ -43,12 +46,12 @@ type inodePlatformFile struct {
fileMapper fsutil.HostFileMapper
// fileMapperInitOnce is used to lazily initialize fileMapper.
- fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ fileMapperInitOnce sync.Once `state:"nosave"`
}
+var _ memmap.File = (*inodePlatformFile)(nil)
+
// IncRef implements memmap.File.IncRef.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.IncRefAndAccount(fr)
@@ -56,8 +59,6 @@ func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
}
// DecRef implements memmap.File.DecRef.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.DecRefAndAccount(fr)
@@ -65,8 +66,6 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
}
// MapInternal implements memmap.File.MapInternal.
-//
-// Precondition: i.inode.canMap must be true.
func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
@@ -76,10 +75,32 @@ func (i *inodePlatformFile) FD() int {
return i.hostFD
}
-// AddMapping implements memmap.Mappable.AddMapping.
+// CachedMappable implements memmap.Mappable. This utility can be embedded in a
+// kernfs.Inode that represents a host file to make the inode mappable.
+// CachedMappable caches the mappings of the host file. CachedMappable must be
+// initialized (via Init) with a hostFD before use.
//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+// +stateify savable
+type CachedMappable struct {
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of hostFD into memmap.MappingSpaces.
+ mappings memmap.MappingSet
+
+ // pf implements memmap.File for mappings backed by a host fd.
+ pf inodePlatformFile
+}
+
+var _ memmap.Mappable = (*CachedMappable)(nil)
+
+// Init initializes i.pf. This must be called before using CachedMappable.
+func (i *CachedMappable) Init(hostFD int) {
+ i.pf.hostFD = hostFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
i.mapsMu.Lock()
mapped := i.mappings.AddMapping(ms, ar, offset, writable)
for _, r := range mapped {
@@ -90,9 +111,7 @@ func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar userm
}
// RemoveMapping implements memmap.Mappable.RemoveMapping.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
i.mapsMu.Lock()
unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
for _, r := range unmapped {
@@ -102,16 +121,12 @@ func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar us
}
// CopyMapping implements memmap.Mappable.CopyMapping.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
return i.AddMapping(ctx, ms, dstAR, offset, writable)
}
// Translate implements memmap.Mappable.Translate.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
mr := optional
return []memmap.Translation{
{
@@ -124,10 +139,26 @@ func (i *inode) Translate(ctx context.Context, required, optional memmap.Mappabl
}
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
-//
-// Precondition: i.inode.canMap must be true.
-func (i *inode) InvalidateUnsavable(ctx context.Context) error {
+func (i *CachedMappable) InvalidateUnsavable(ctx context.Context) error {
// We expect the same host fd across save/restore, so all translations
// should be valid.
return nil
}
+
+// InvalidateRange invalidates the passed range on i.mappings.
+func (i *CachedMappable) InvalidateRange(r memmap.MappableRange) {
+ i.mapsMu.Lock()
+ i.mappings.Invalidate(r, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ i.mapsMu.Unlock()
+}
+
+// InitFileMapperOnce initializes the host file mapper. It ensures that the
+// file mapper is initialized just once.
+func (i *CachedMappable) InitFileMapperOnce() {
+ i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/save_restore.go b/pkg/sentry/fsimpl/kernfs/save_restore.go
new file mode 100644
index 000000000..f78509eb7
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/save_restore.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// afterLoad is invoked by stateify.
+func (d *Dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) >= 0 {
+ refsvfs2.Register(d)
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodePlatformFile) afterLoad() {
+ if i.fileMapper.IsInited() {
+ // Ensure that we don't call i.fileMapper.Init() again.
+ i.fileMapperInitOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 934cc6c9e..a0736c0d6 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -38,16 +38,16 @@ type StaticSymlink struct {
var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
+func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
inode := &StaticSymlink{}
- inode.Init(creds, devMajor, devMinor, ino, target)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, target)
return inode
}
// Init initializes the instance.
-func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
+func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
s.target = target
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
}
// Readlink implements Inode.Readlink.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index d0ed17b18..463d77d79 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -41,17 +41,17 @@ type syntheticDirectory struct {
var _ Inode = (*syntheticDirectory)(nil)
-func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode {
+func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode {
inode := &syntheticDirectory{}
- inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
+ inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
return inode
}
-func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm))
}
- dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
+ dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
dir.OrderedChildren.Init(OrderedChildrenOptions{
Writable: true,
})
@@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs
if !opts.ForSyntheticMountpoint {
return nil, syserror.EPERM
}
- subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
if err := dir.OrderedChildren.Insert(name, subdirI); err != nil {
subdirI.DecRef(ctx)
return nil, err
}
+ dir.TouchCMtime(ctx)
return subdirI, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index 1e11b0428..bf13bbbf4 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -23,6 +23,7 @@ go_library(
"fstree.go",
"overlay.go",
"regular_file.go",
+ "save_restore.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
@@ -30,6 +31,8 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 4506642ca..469f3a33d 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er
if dirent.Name == "." || dirent.Name == ".." {
continue
}
- child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
+ child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 78a01bbb7..bc07d72c0 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
// * !rp.Done().
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) {
if !d.isDir() {
- return nil, syserror.ENOTDIR
+ return nil, lookupLayerNone, syserror.ENOTDIR
}
if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
afterSymlink:
name := rp.Component()
if name == "." {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if name == ".." {
if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
} else if isRoot || d.parent == nil {
rp.Advance()
- return d, nil
+ return d, d.topLookupLayer(), nil
}
if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
rp.Advance()
- return d.parent, nil
+ return d.parent, d.parent.topLookupLayer(), nil
}
- child, err := fs.getChildLocked(ctx, d, name, ds)
+ child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
target, err := child.readlink(ctx)
if err != nil {
- return nil, err
+ return nil, lookupLayerNone, err
}
if err := rp.HandleSymlink(target); err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
goto afterSymlink // don't check the current directory again
}
rp.Advance()
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * d.dirMu must be locked.
-func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) {
if child, ok := parent.children[name]; ok {
- return child, nil
+ return child, child.topLookupLayer(), nil
}
- child, err := fs.lookupLocked(ctx, parent, name)
+ child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name)
if err != nil {
- return nil, err
+ return nil, topLookupLayer, err
}
if parent.children == nil {
parent.children = make(map[string]*dentry)
@@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
parent.children[name] = child
// child's refcount is initially 0, so it may be dropped after traversal.
*ds = appendDentry(*ds, child)
- return child, nil
+ return child, topLookupLayer, nil
}
// Preconditions:
// * fs.renameMu must be locked.
// * parent.dirMu must be locked.
-func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) {
childPath := fspath.Parse(name)
child := fs.newDentry()
- existsOnAnyLayer := false
+ topLookupLayer := lookupLayerNone
var lookupErr error
vfsObj := fs.vfsfs.VirtualFilesystem()
@@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
defer childVD.DecRef(ctx)
mask := uint32(linux.STATX_TYPE)
- if !existsOnAnyLayer {
+ if topLookupLayer == lookupLayerNone {
// Mode, UID, GID, and (for non-directories) inode number come from
// the topmost layer on which the file exists.
mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
@@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if isWhiteout(&stat) {
// This is a whiteout, so it "doesn't exist" on this layer, and
// layers below this one are ignored.
+ if isUpper {
+ topLookupLayer = lookupLayerUpperWhiteout
+ }
return false
}
isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
- if existsOnAnyLayer && !isDir {
+ if topLookupLayer != lookupLayerNone && !isDir {
// Directories are not merged with non-directory files from lower
// layers; instead, layers including and below the first
// non-directory file are ignored. (This file must be a directory
@@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
} else {
child.lowerVDs = append(child.lowerVDs, childVD)
}
- if !existsOnAnyLayer {
- existsOnAnyLayer = true
+ if topLookupLayer == lookupLayerNone {
+ if isUpper {
+ topLookupLayer = lookupLayerUpper
+ } else {
+ topLookupLayer = lookupLayerLower
+ }
child.mode = uint32(stat.Mode)
child.uid = stat.UID
child.gid = stat.GID
@@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
if lookupErr != nil {
child.destroyLocked(ctx)
- return nil, lookupErr
+ return nil, topLookupLayer, lookupErr
}
- if !existsOnAnyLayer {
+ if !topLookupLayer.existsInOverlay() {
child.destroyLocked(ctx)
- return nil, syserror.ENOENT
+ return nil, topLookupLayer, syserror.ENOENT
}
// Device and inode numbers were copied from the topmost layer above;
@@ -302,14 +310,20 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
child.devMinor = fs.dirDevMinor
child.ino = fs.newDirIno()
} else if !child.upperVD.Ok() {
+ childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor)
+ if err != nil {
+ ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err)
+ child.destroyLocked(ctx)
+ return nil, topLookupLayer, err
+ }
child.devMajor = linux.UNNAMED_MAJOR
- child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ child.devMinor = childDevMinor
}
parent.IncRef()
child.parent = parent
child.name = name
- return child, nil
+ return child, topLookupLayer, nil
}
// lookupLayerLocked is similar to lookupLocked, but only returns information
@@ -408,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool {
func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
for !rp.Final() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -428,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath,
d := rp.Start().Impl().(*dentry)
for !rp.Done() {
d.dirMu.Lock()
- next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
d.dirMu.Unlock()
if err != nil {
return nil, err
@@ -463,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
if name == "." || name == ".." {
return syserror.EEXIST
}
- if !dir && rp.MustBeDir() {
- return syserror.ENOENT
- }
if parent.vfsd.IsDead() {
return syserror.ENOENT
}
@@ -489,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
return syserror.EEXIST
}
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+
// Ensure that the parent directory is copied-up so that we can create the
// new file in the upper layer.
if err := parent.copyUpLocked(ctx); err != nil {
@@ -791,9 +806,9 @@ afterTrailingSymlink:
}
// Determine whether or not we need to create a file.
parent.dirMu.Lock()
- child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
if err == syserror.ENOENT && mayCreate {
- fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout)
parent.dirMu.Unlock()
return fd, err
}
@@ -893,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
// Preconditions:
// * parent.dirMu must be locked.
// * parent does not already contain a child named rp.Component().
-func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) {
creds := rp.Credentials()
if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
return nil, err
@@ -918,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
Start: parent.upperVD,
Path: fspath.Parse(childName),
}
- // We don't know if a whiteout exists on the upper layer; speculatively
- // unlink it.
- //
- // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
- // know whether a whiteout exists.
- var haveUpperWhiteout bool
- switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
- case nil:
- haveUpperWhiteout = true
- case syserror.ENOENT:
- haveUpperWhiteout = false
- default:
- return nil, err
+ // Unlink the whiteout if it exists.
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err)
+ return nil, err
+ }
}
// Create the file on the upper layer, and get an FD representing it.
upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
@@ -961,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
// Re-lookup to get a dentry representing the new file, which is needed for
// the returned FD.
- child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ child, _, err := fs.getChildLocked(ctx, parent, childName, ds)
if err != nil {
if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr))
@@ -970,7 +978,10 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
return nil, err
}
- // Finally construct the overlay FD.
+ // Finally construct the overlay FD. Below this point, we don't perform
+ // cleanup (the file was created successfully even if we can no longer open
+ // it for some reason).
+ parent.dirents = nil
upperFlags := upperFD.StatusFlags()
fd := &regularFileFD{
copiedUp: true,
@@ -981,8 +992,6 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
upperFDOpts := upperFD.Options()
if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
upperFD.DecRef(ctx)
- // Don't bother with cleanup; the file was created successfully, we
- // just can't open it anymore for some reason.
return nil, err
}
parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */)
@@ -1040,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
// directory, we need to check for write permission on it.
oldParent.dirMu.Lock()
defer oldParent.dirMu.Unlock()
- renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
+ renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds)
if err != nil {
return err
}
@@ -1072,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
if newParent.vfsd.IsDead() {
return syserror.ENOENT
}
- replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName)
- if err != nil {
- return err
- }
var (
- replaced *dentry
- replacedVFSD *vfs.Dentry
- whiteouts map[string]bool
+ replaced *dentry
+ replacedVFSD *vfs.Dentry
+ replacedLayer lookupLayer
+ whiteouts map[string]bool
)
- if replacedLayer.existsInOverlay() {
- replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds)
- if err != nil {
- return err
- }
+ replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds)
+ if err != nil && err != syserror.ENOENT {
+ return err
+ }
+ if replaced != nil {
replacedVFSD = &replaced.vfsd
if replaced.isDir() {
if !renamed.isDir() {
@@ -1289,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// Unlike UnlinkAt, we need a dentry representing the child directory being
// removed in order to verify that it's empty.
- child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err := fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
@@ -1541,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if parentMode&linux.S_ISVTX != 0 {
// If the parent's sticky bit is set, we need a child dentry to get
// its owner.
- child, err = fs.getChildLocked(ctx, parent, name, &ds)
+ child, _, err = fs.getChildLocked(ctx, parent, name, &ds)
if err != nil {
return err
}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index 4c5de8d32..73130bc8d 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -22,6 +22,7 @@
// filesystem.renameMu
// dentry.dirMu
// dentry.copyMu
+// filesystem.devMu
// *** "memmap.Mappable locks" below this point
// dentry.mapsMu
// *** "memmap.Mappable locks taken by Translate" below this point
@@ -33,12 +34,14 @@
package overlay
import (
+ "fmt"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -99,10 +102,15 @@ type filesystem struct {
// is immutable.
dirDevMinor uint32
- // lowerDevMinors maps lower layer filesystems to device minor numbers
- // assigned to non-directory files originating from that filesystem.
- // lowerDevMinors is immutable.
- lowerDevMinors map[*vfs.Filesystem]uint32
+ // lowerDevMinors maps device numbers from lower layer filesystems to
+ // device minor numbers assigned to non-directory files originating from
+ // that filesystem. (This remapping is necessary for lower layers because a
+ // file on a lower layer, and that same file on an overlay, are
+ // distinguishable because they will diverge after copy-up; this isn't true
+ // for non-directory files already on the upper layer.) lowerDevMinors is
+ // protected by devMu.
+ devMu sync.Mutex `state:"nosave"`
+ lowerDevMinors map[layerDevNumber]uint32
// renameMu synchronizes renaming with non-renaming operations in order to
// ensure consistent lock ordering between dentry.dirMu in different
@@ -114,78 +122,69 @@ type filesystem struct {
lastDirIno uint64
}
+// +stateify savable
+type layerDevNumber struct {
+ major uint32
+ minor uint32
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mopts := vfs.GenericParseMountOptions(opts.Data)
fsoptsRaw := opts.InternalData
- fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
- if fsoptsRaw != nil && !haveFSOpts {
+ fsopts, ok := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !ok {
ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
return nil, nil, syserror.EINVAL
}
- if haveFSOpts {
- if len(fsopts.LowerRoots) == 0 {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+
+ if upperPathname, ok := mopts["upperdir"]; ok {
+ if fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified")
return nil, nil, syserror.EINVAL
}
- if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
return nil, nil, syserror.EINVAL
}
- // We don't enforce a maximum number of lower layers when not
- // configured by applications; the sandbox owner can have an overlay
- // filesystem with any number of lower layers.
- } else {
- vfsroot := vfs.RootFromContext(ctx)
- defer vfsroot.DecRef(ctx)
- upperPathname, ok := mopts["upperdir"]
- if ok {
- delete(mopts, "upperdir")
- // Linux overlayfs also requires a workdir when upperdir is
- // specified; we don't, so silently ignore this option.
- delete(mopts, "workdir")
- upperPath := fspath.Parse(upperPathname)
- if !upperPath.Absolute {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
- return nil, nil, syserror.EINVAL
- }
- upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
- Root: vfsroot,
- Start: vfsroot,
- Path: upperPath,
- FollowFinalSymlink: true,
- }, &vfs.GetDentryOptions{
- CheckSearchable: true,
- })
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer upperRoot.DecRef(ctx)
- privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer privateUpperRoot.DecRef(ctx)
- fsopts.UpperRoot = privateUpperRoot
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ upperRoot.DecRef(ctx)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
}
- lowerPathnamesStr, ok := mopts["lowerdir"]
- if !ok {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ defer privateUpperRoot.DecRef(ctx)
+ fsopts.UpperRoot = privateUpperRoot
+ }
+
+ if lowerPathnamesStr, ok := mopts["lowerdir"]; ok {
+ if len(fsopts.LowerRoots) != 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified")
return nil, nil, syserror.EINVAL
}
delete(mopts, "lowerdir")
lowerPathnames := strings.Split(lowerPathnamesStr, ":")
- const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
- if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
- return nil, nil, syserror.EINVAL
- }
- if len(lowerPathnames) > maxLowerLayers {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
- return nil, nil, syserror.EINVAL
- }
for _, lowerPathname := range lowerPathnames {
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
@@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
- defer lowerRoot.DecRef(ctx)
privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ lowerRoot.DecRef(ctx)
if err != nil {
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
@@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
}
}
+
if len(mopts) != 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
return nil, nil, syserror.EINVAL
}
- // Allocate device numbers.
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present")
+ return nil, nil, syserror.EINVAL
+ }
+ const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
+ if len(fsopts.LowerRoots) > maxLowerLayers {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate dirDevMinor. lowerDevMinors are allocated dynamically.
dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
- lowerDevMinors := make(map[*vfs.Filesystem]uint32)
- for _, lowerRoot := range fsopts.LowerRoots {
- lowerFS := lowerRoot.Mount().Filesystem()
- if _, ok := lowerDevMinors[lowerFS]; !ok {
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- vfsObj.PutAnonBlockDevMinor(dirDevMinor)
- for _, lowerDevMinor := range lowerDevMinors {
- vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
- }
- return nil, nil, err
- }
- lowerDevMinors[lowerFS] = devMinor
- }
- }
// Take extra references held by the filesystem.
if fsopts.UpperRoot.Ok() {
@@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
opts: fsopts,
creds: creds.Fork(),
dirDevMinor: dirDevMinor,
- lowerDevMinors: lowerDevMinors,
+ lowerDevMinors: make(map[layerDevNumber]uint32),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
root.ino = fs.newDirIno()
} else if !root.upperVD.Ok() {
root.devMajor = linux.UNNAMED_MAJOR
- root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err)
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+ root.devMinor = rootDevMinor
root.ino = rootStat.Ino
} else {
root.devMajor = rootStat.DevMajor
@@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 {
return atomic.AddUint64(&fs.lastDirIno, 1)
}
+func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) {
+ fs.devMu.Lock()
+ defer fs.devMu.Unlock()
+ orig := layerDevNumber{layerMajor, layerMinor}
+ if minor, ok := fs.lowerDevMinors[orig]; ok {
+ return minor, nil
+ }
+ minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor()
+ if err != nil {
+ return 0, err
+ }
+ fs.lowerDevMinors[orig] = minor
+ return minor, nil
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -458,9 +479,9 @@ type dentry struct {
//
// - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is
// accessed using atomic memory operations.
- mapsMu sync.Mutex
+ mapsMu sync.Mutex `state:"nosave"`
lowerMappings memmap.MappingSet
- dataMu sync.RWMutex
+ dataMu sync.RWMutex `state:"nosave"`
wrappedMappable memmap.Mappable
isMappable uint32
@@ -484,6 +505,7 @@ func (fs *filesystem) newDentry() *dentry {
}
d.lowerVDs = d.inlineLowerVDs[:0]
d.vfsd.Init(d)
+ refsvfs2.Register(d)
return d
}
@@ -491,17 +513,19 @@ func (fs *filesystem) newDentry() *dentry {
func (d *dentry) IncRef() {
// d.refs may be 0 if d.fs.renameMu is locked, which serializes against
// d.checkDropLocked().
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -509,15 +533,27 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
+ } else if r < 0 {
panic("overlay.dentry.DecRef() called without holding a reference")
}
}
+func (d *dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.checkDropLocked(ctx)
+ } else if r < 0 {
+ panic("overlay.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
// checkDropLocked should be called after d's reference count becomes 0 or it
// becomes deleted.
//
@@ -577,12 +613,27 @@ func (d *dentry) destroyLocked(ctx context.Context) {
d.parent.dirMu.Unlock()
// Drop the reference held by d on its parent without recursively
// locking d.fs.renameMu.
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkDropLocked(ctx)
- } else if refs < 0 {
- panic("overlay.dentry.DecRef() called without holding a reference")
- }
+ d.parent.decRefLocked(ctx)
}
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "overlay.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -645,6 +696,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry {
return vd
}
+func (d *dentry) topLookupLayer() lookupLayer {
+ if d.upperVD.Ok() {
+ return lookupLayerUpper
+ }
+ return lookupLayerLower
+}
+
func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
}
diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go
new file mode 100644
index 000000000..54809f16c
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index e44b79b68..0ecb592cf 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -101,7 +101,7 @@ type inode struct {
func newInode(ctx context.Context, fs *filesystem) *inode {
creds := auth.CredentialsFromContext(ctx)
return &inode{
- pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize),
ino: fs.Filesystem.NextIno(),
uid: creds.EffectiveKUID,
gid: creds.EffectiveKGID,
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 2e086e34c..5196a2a80 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "fd_dir_inode_refs.go",
package = "proc",
prefix = "fdDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdDirInode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "fd_info_dir_inode_refs.go",
package = "proc",
prefix = "fdInfoDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdInfoDirInode",
},
@@ -30,7 +30,7 @@ go_template_instance(
out = "subtasks_inode_refs.go",
package = "proc",
prefix = "subtasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "subtasksInode",
},
@@ -41,7 +41,7 @@ go_template_instance(
out = "task_inode_refs.go",
package = "proc",
prefix = "taskInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "taskInode",
},
@@ -52,7 +52,7 @@ go_template_instance(
out = "tasks_inode_refs.go",
package = "proc",
prefix = "tasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tasksInode",
},
@@ -82,6 +82,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index fd70a07de..8716d0a3c 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -17,6 +17,7 @@ package proc
import (
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -24,10 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "proc"
+const (
+ // Name is the default filesystem name.
+ Name = "proc"
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType is the factory class for procfs.
//
@@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
if err != nil {
return nil, nil, err
}
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
procfs := &filesystem{
devMinor: devMinor,
}
+ procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
var cgroups map[string]string
@@ -74,9 +92,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
cgroups = data.Cgroups
}
- inode := procfs.newTasksInode(k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
var dentry kernfs.Dentry
- dentry.Init(&procfs.Filesystem, inode)
+ dentry.InitRoot(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
@@ -94,11 +112,11 @@ type dynamicInode interface {
kernfs.Inode
vfs.DynamicBytesSource
- Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+ Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
}
-func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
+func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
return inode
}
@@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile {
return &staticFile{StaticData: vfs.StaticData{Data: data}}
}
-func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
- return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
+func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
+ return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index bad2fab4f..cb3c5e0fd 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
cgroupControllers: cgroupControllers,
}
// Note: credentials are overridden by taskOwnedInode.
- subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
subInode.EnableLeakCheck()
@@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
return offset, syserror.ENOENT
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index b63a4eca0..19011b010 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
"gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
"io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
"maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mem": fs.newMemInode(task, fs.NextIno(), 0400),
"mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
"mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}),
"net": fs.newTaskNetDir(task),
@@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode := &taskInode{task: task}
// Note: credentials are overridden by taskOwnedInode.
- taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
taskInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
@@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil)
func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
return &taskOwnedInode{Inode: inode, owner: task}
}
@@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu
func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero}
- dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
+ dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
return &taskOwnedInode{Inode: dir, owner: task}
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 2c80ac5c2..d268b44be 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -64,7 +64,7 @@ type fdDir struct {
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
@@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
produceSymlink: true,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Lookup implements kernfs.inodeDirectory.Lookup.
@@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern
task: task,
fd: fd,
}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
task: task,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
@@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements Inode.IterDirents.
-func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 79f8b7e9f..ba71d0fde 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -249,7 +250,7 @@ type commInode struct {
func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
inode := &commInode{task: task}
- inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
+ inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
return inode
}
@@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
return int64(srclen), nil
}
+var _ kernfs.Inode = (*memInode)(nil)
+
+// memInode implements kernfs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoStatFS
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ task *kernel.Task
+ locks vfs.FileLocks
+}
+
+func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode := &memInode{task: task}
+ inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm)
+ return &taskOwnedInode{Inode: inode, owner: task}
+}
+
+func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+}
+
+// Open implements kernfs.Inode.Open.
+func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, f.task, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(f.task); err != nil {
+ return nil, err
+ }
+ fd := &memFD{}
+ if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+var _ vfs.FileDescriptionImpl = (*memFD)(nil)
+
+// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem.
+//
+// +stateify savable
+type memFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *memInode
+
+ // mu guards the fields below.
+ mu sync.Mutex `state:"nosave"`
+ offset int64
+}
+
+// Init initializes memFD.
+func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error {
+ fd.LockFD.Init(&inode.locks)
+ if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return err
+ }
+ fd.inode = inode
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ case linux.SEEK_CUR:
+ offset += fd.offset
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.offset = offset
+ return offset, nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ m, err := getMMIncRef(fd.inode.task)
+ if err != nil {
+ return 0, nil
+ }
+ defer m.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.offset, opts)
+ fd.offset += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *memFD) Release(context.Context) {}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil)
func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &exeSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil)
func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &cwdSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
inode := &namespaceSymlink{task: task}
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
taskInode := &taskOwnedInode{Inode: inode, owner: task}
return taskInode
@@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
// Create a synthetic inode to represent the namespace.
fs := mnt.Filesystem().Impl().(*filesystem)
+ nsInode := &namespaceInode{}
+ nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444)
dentry := &kernfs.Dentry{}
- dentry.Init(&fs.Filesystem, &namespaceInode{})
+ dentry.Init(&fs.Filesystem, nsInode)
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
// Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1.
mnt.IncRef()
@@ -897,11 +1056,11 @@ type namespaceInode struct {
var _ kernfs.Inode = (*namespaceInode)(nil)
// Init initializes a namespace inode.
-func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 3425e8698..5a9ee111f 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode {
// TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
// network namespace.
contents = map[string]kernfs.Inode{
- "dev": fs.newInode(root, 0444, &netDevData{stack: stack}),
- "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}),
+ "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}),
+ "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, if the file contains a header the stub is just the header
// otherwise it is an empty file.
- "arp": fs.newInode(root, 0444, newStaticFile(arp)),
- "netlink": fs.newInode(root, 0444, newStaticFile(netlink)),
- "netstat": fs.newInode(root, 0444, &netStatData{}),
- "packet": fs.newInode(root, 0444, newStaticFile(packet)),
- "protocols": fs.newInode(root, 0444, newStaticFile(protocols)),
+ "arp": fs.newInode(task, root, 0444, newStaticFile(arp)),
+ "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)),
+ "netstat": fs.newInode(task, root, 0444, &netStatData{}),
+ "packet": fs.newInode(task, root, 0444, newStaticFile(packet)),
+ "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)),
// Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
// high res timer ticks per sec (ClockGetres returns 1ns resolution).
- "psched": fs.newInode(root, 0444, newStaticFile(psched)),
- "ptype": fs.newInode(root, 0444, newStaticFile(ptype)),
- "route": fs.newInode(root, 0444, &netRouteData{stack: stack}),
- "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}),
- "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}),
- "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}),
+ "psched": fs.newInode(task, root, 0444, newStaticFile(psched)),
+ "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)),
+ "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}),
+ "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}),
+ "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}),
}
if stack.SupportsIPv6() {
- contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack})
- contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile(""))
- contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k})
- contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6))
+ contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6))
}
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 3259c3732..b81ea14bf 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -62,19 +62,19 @@ type tasksInode struct {
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
- "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))),
- "filesystems": fs.newInode(root, 0444, &filesystemsData{}),
- "loadavg": fs.newInode(root, 0444, &loadavgData{}),
- "sys": fs.newSysDir(root, k),
- "meminfo": fs.newInode(root, 0444, &meminfoData{}),
- "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
- "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
- "stat": fs.newInode(root, 0444, &statData{}),
- "uptime": fs.newInode(root, 0444, &uptimeData{}),
- "version": fs.newInode(root, 0444, &versionData{}),
+ "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}),
+ "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}),
+ "sys": fs.newSysDir(ctx, root, k),
+ "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
+ "stat": fs.newInode(ctx, root, 0444, &statData{}),
+ "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
+ "version": fs.newInode(ctx, root, 0444, &versionData{}),
}
inode := &tasksInode{
@@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace
fs: fs,
cgroupControllers: cgroupControllers,
}
- inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
@@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
// If it failed to parse, check if it's one of the special handled files.
switch name {
case selfName:
- return i.newSelfSymlink(root), nil
+ return i.newSelfSymlink(ctx, root), nil
case threadSelfName:
- return i.newThreadSelfSymlink(root), nil
+ return i.newThreadSelfSymlink(ctx, root), nil
}
return nil, syserror.ENOENT
}
@@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 07c27cdd9..01b7a6678 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -43,9 +43,9 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &selfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
@@ -84,9 +84,9 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &threadSelfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 95420368d..7c7afdcfa 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -40,93 +40,93 @@ const (
)
// newSysDir returns the dentry corresponding to /proc/sys directory.
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
- return fs.newStaticDir(root, map[string]kernfs.Inode{
- "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{
- "hostname": fs.newInode(root, 0444, &hostnameData{}),
- "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)),
+func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
}),
- "vm": fs.newStaticDir(root, map[string]kernfs.Inode{
- "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}),
- "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")),
+ "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")),
}),
- "net": fs.newSysNetDir(root, k),
+ "net": fs.newSysNetDir(ctx, root, k),
})
}
// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
-func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
var contents map[string]kernfs.Inode
// TODO(gvisor.dev/issue/1833): Support for using the network stack in the
// network namespace of the calling process.
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]kernfs.Inode{
- "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{
- "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}),
+ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")),
- "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")),
- "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")),
- "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")),
- "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")),
+ "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")),
// tcp_allowed_congestion_control tell the user what they are able to
// do as an unprivledged process so we leave it empty.
- "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
- "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
+ "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
// Many of the following stub files are features netstack doesn't
// support. The unsupported features return "0" to indicate they are
// disabled.
- "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")),
- "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")),
- "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")),
- "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")),
- "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")),
}),
- "core": fs.newStaticDir(root, map[string]kernfs.Inode{
- "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")),
- "message_burst": fs.newInode(root, 0444, newStaticFile("10")),
- "message_cost": fs.newInode(root, 0444, newStaticFile("5")),
- "optmem_max": fs.newInode(root, 0444, newStaticFile("0")),
- "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
- "somaxconn": fs.newInode(root, 0444, newStaticFile("128")),
- "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
+ "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")),
+ "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")),
+ "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
}),
}
}
- return fs.newStaticDir(root, contents)
+ return fs.newStaticDir(ctx, root, contents)
}
// mmapMinAddrData implements vfs.DynamicBytesSource for
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 2582ababd..7ee6227a9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -77,6 +77,7 @@ var (
"gid_map": linux.DT_REG,
"io": linux.DT_REG,
"maps": linux.DT_REG,
+ "mem": linux.DT_REG,
"mountinfo": linux.DT_REG,
"mounts": linux.DT_REG,
"net": linux.DT_DIR,
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index cf91ea36c..fda1fa942 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// NewDentry constructs and returns a sockfs dentry.
//
// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
-func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
+func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry {
fs := mnt.Filesystem().Impl().(*filesystem)
// File mode matches net/socket.c:sock_alloc.
filemode := linux.FileMode(linux.S_IFSOCK | 0600)
i := &inode{}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
+ i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
d := &kernfs.Dentry{}
d.Init(&fs.Filesystem, i)
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 906cd52cb..09043b572 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "dir_refs.go",
package = "sys",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/coverage",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index 31a361029..b13f141a8 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -29,7 +29,7 @@ import (
func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
k := &kcovInode{}
- k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
+ k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
return k
}
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1ad679830..506a2a0f0 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -18,6 +18,7 @@ package sys
import (
"bytes"
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -29,9 +30,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "sysfs"
-const defaultSysDirMode = linux.FileMode(0755)
+const (
+ // Name is the default filesystem name.
+ Name = "sysfs"
+ defaultSysDirMode = linux.FileMode(0755)
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType implements vfs.FilesystemType.
//
@@ -62,31 +66,43 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
fs := &filesystem{
devMinor: devMinor,
}
+ fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
- root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "block": fs.newDir(creds, defaultSysDirMode, nil),
- "bus": fs.newDir(creds, defaultSysDirMode, nil),
- "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "power_supply": fs.newDir(creds, defaultSysDirMode, nil),
+ root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil),
}),
- "dev": fs.newDir(creds, defaultSysDirMode, nil),
- "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"cpu": cpuDir(ctx, fs, creds),
}),
}),
- "firmware": fs.newDir(creds, defaultSysDirMode, nil),
- "fs": fs.newDir(creds, defaultSysDirMode, nil),
+ "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"kernel": kernelDir(ctx, fs, creds),
- "module": fs.newDir(creds, defaultSysDirMode, nil),
- "power": fs.newDir(creds, defaultSysDirMode, nil),
+ "module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
})
var rootD kernfs.Dentry
- rootD.Init(&fs.Filesystem, root)
+ rootD.InitRoot(&fs.Filesystem, root)
return fs.VFSFilesystem(), rootD.VFSDentry(), nil
}
@@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
children := map[string]kernfs.Inode{
- "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
}
for i := uint(0); i < maxCPUCores; i++ {
- children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil)
+ children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil)
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
@@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker
var children map[string]kernfs.Inode
if coverage.KcovAvailable() {
children = map[string]kernfs.Inode{
- "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{
+ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{
"kcov": fs.newKcovFile(ctx, creds),
}),
}
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
// Release implements vfs.FilesystemImpl.Release.
@@ -140,9 +156,9 @@ type dir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
d := &dir{}
- d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
d.EnableLeakCheck()
d.IncLinks(d.OrderedChildren.Populate(contents))
@@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
+func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
c := &cpuFile{maxCores: maxCores}
- c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
+ c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
return c
}
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 5cd428d64..fe520b6fd 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -31,7 +31,7 @@ go_template_instance(
out = "inode_refs.go",
package = "tmpfs",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -48,6 +48,7 @@ go_library(
"inode_refs.go",
"named_pipe.go",
"regular_file.go",
+ "save_restore.go",
"socket_file.go",
"symlink.go",
"tmpfs.go",
@@ -60,6 +61,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index d772db9e9..57e7b57b0 100644
--- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -18,7 +18,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/usermem"
)
// +stateify savable
@@ -32,7 +31,7 @@ type namedPipe struct {
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)}
file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index ce4e3eda7..98680fde9 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -42,7 +42,7 @@ type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
- memFile *pgalloc.MemoryFile
+ memFile *pgalloc.MemoryFile `state:"nosave"`
// memoryUsageKind is the memory accounting category under which pages backing
// this regularFile's contents are accounted.
@@ -92,7 +92,7 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
- memFile: fs.memFile,
+ memFile: fs.mfp.MemoryFile(),
memoryUsageKind: usage.Tmpfs,
seals: linux.F_SEAL_SEAL,
}
diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go
new file mode 100644
index 000000000..b27f75cc2
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+// afterLoad is called by stateify.
+func (rf *regularFile) afterLoad() {
+ rf.memFile = rf.inode.fs.mfp.MemoryFile()
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index e2a0aac69..4ce859d57 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -61,8 +61,9 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
- // memFile is used to allocate pages to for regular files.
- memFile *pgalloc.MemoryFile
+ // mfp is used to allocate memory that stores regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
// clock is a realtime clock used to set timestamps in file operations.
clock time.Clock
@@ -106,8 +107,8 @@ type FilesystemOpts struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
- if memFileProvider == nil {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
panic("MemoryFileProviderFromContext returned nil")
}
@@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
clock := time.RealtimeClockFromContext(ctx)
fs := filesystem{
- memFile: memFileProvider.MemoryFile(),
+ mfp: mfp,
clock: clock,
devMinor: devMinor,
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 0ca750281..e265be0ee 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "verity",
srcs = [
"filesystem.go",
+ "save_restore.go",
"verity.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -15,6 +16,7 @@ go_library(
"//pkg/fspath",
"//pkg/marshal/primitive",
"//pkg/merkletree",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
@@ -38,10 +40,12 @@ go_test(
"//pkg/context",
"//pkg/fspath",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 03da505e1..4e8d63d51 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's hash.
@@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Start: parent.lowerVD,
}, &vfs.StatOptions{})
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// Verify returns with success.
var buf bytes.Buffer
if _, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: &buf,
- File: &fdReader,
- Tree: &fdReader,
- Size: int64(parentSize),
- Name: parent.name,
- Mode: uint32(parentStat.Mode),
- UID: parentStat.UID,
- GID: parentStat.GID,
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
ReadOffset: int64(offset),
- ReadSize: int64(merkletree.DigestSize()),
+ ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())),
Expected: parent.hash,
DataAndTreeInSameFile: true,
}); err != nil && err != io.EOF {
- return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
@@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
return err
@@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return err
@@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
size, err := strconv.Atoi(merkleSize)
if err != nil {
- return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
var buf bytes.Buffer
params := &merkletree.VerifyParams{
- Out: &buf,
- Tree: &fdReader,
- Size: int64(size),
- Name: d.name,
- Mode: uint32(stat.Mode),
- UID: stat.UID,
- GID: stat.GID,
- ReadOffset: 0,
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fs.alg.toLinuxHashAlg(),
+ ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
ReadSize: 0,
Expected: d.hash,
@@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
}
if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
- return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
}
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
+ d.size = uint32(size)
return nil
}
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
+ // If verity is enabled on child, we should check again whether
+ // the file and the corresponding Merkle tree are as expected,
+ // in order to catch deletion/renaming after the last time it's
+ // accessed.
+ if child.verityEnabled() {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ // Get the path to the child dentry. This is only used
+ // to provide path information in failure case.
+ path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD)
+ if err != nil {
+ return nil, err
+ }
+
+ childVD, err := parent.getLowerAt(ctx, vfsObj, name)
+ if err == syserror.ENOENT {
+ // The file was previously accessed. If the
+ // file does not exist now, it indicates an
+ // unexpected modification to the file system.
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer childVD.DecRef(ctx)
+
+ childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
+ // The Merkle tree file was previous accessed. If it
+ // does not exist now, it indicates an unexpected
+ // modification to the file system.
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ defer childMerkleVD.DecRef(ctx)
+ }
+
// If enabling verification on files/directories is not allowed
// during runtime, all cached children are already verified. If
// runtime enable is allowed and the parent directory is
@@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
- childFilename := fspath.Parse(name)
- childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: childFilename,
- }, &vfs.GetDentryOptions{})
-
+ childVD, childErr := parent.getLowerAt(ctx, vfsObj, name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childErr != nil && childErr != syserror.ENOENT {
@@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
defer childVD.DecRef(ctx)
}
- childMerkleFilename := merklePrefix + name
- childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
-
+ childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childMerkleErr != nil && childMerkleErr != syserror.ENOENT {
@@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// corresponding Merkle tree is found. This indicates an
// unexpected modification to the file system that
// removed/renamed the child.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
} else if childErr == nil && childMerkleErr == syserror.ENOENT {
// If in allowRuntimeEnable mode, and the Merkle tree file is
// not created yet, we create an empty Merkle tree file, so that
@@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: parent.lowerVD,
Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
+ Path: fspath.Parse(merklePrefix + name),
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT,
Mode: 0644,
@@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
return nil, err
}
childMerkleFD.DecRef(ctx)
- childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
+ childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
if err != nil {
return nil, err
}
@@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// If runtime enable is not allowed. This indicates an
// unexpected modification to the file system that
// removed/renamed the Merkle tree file.
- return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
}
} else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT {
// Both the child and the corresponding Merkle tree are missing.
@@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// TODO(b/167752508): Investigate possible ways to differentiate
// cases that both files are deleted from cases that they never
// exist in the file system.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
}
mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
@@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// missing, it indicates an unexpected modification to the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
}
@@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
})
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err != nil {
if err == syserror.ENOENT {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
return nil, err
}
diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go
new file mode 100644
index 000000000..46b064342
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d)
+ }
+}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 8dc9e26bc..d24c839bb 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -23,6 +23,7 @@ package verity
import (
"fmt"
+ "math"
"strconv"
"sync/atomic"
@@ -31,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -41,32 +43,62 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// Name is the default filesystem name.
-const Name = "verity"
+const (
+ // Name is the default filesystem name.
+ Name = "verity"
-// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle
-// tree file for "/foo" is "/.merkle.verity.foo".
-const merklePrefix = ".merkle.verity."
+ // merklePrefix is the prefix of the Merkle tree files. For example, the Merkle
+ // tree file for "/foo" is "/.merkle.verity.foo".
+ merklePrefix = ".merkle.verity."
-// merkleoffsetInParentXattr is the extended attribute name specifying the
-// offset of child hash in its parent's Merkle tree.
-const merkleOffsetInParentXattr = "user.merkle.offset"
+ // merkleOffsetInParentXattr is the extended attribute name specifying the
+ // offset of the child hash in its parent's Merkle tree.
+ merkleOffsetInParentXattr = "user.merkle.offset"
-// merkleSizeXattr is the extended attribute name specifying the size of data
-// hashed by the corresponding Merkle tree. For a file, it's the size of the
-// whole file. For a directory, it's the size of all its children's hashes.
-const merkleSizeXattr = "user.merkle.size"
+ // merkleSizeXattr is the extended attribute name specifying the size of data
+ // hashed by the corresponding Merkle tree. For a regular file, this is the
+ // file size. For a directory, this is the size of all its children's hashes.
+ merkleSizeXattr = "user.merkle.size"
-// sizeOfStringInt32 is the size for a 32 bit integer stored as string in
-// extended attributes. The maximum value of a 32 bit integer is 10 digits.
-const sizeOfStringInt32 = 10
+ // sizeOfStringInt32 is the size for a 32 bit integer stored as string in
+ // extended attributes. The maximum value of a 32 bit integer has 10 digits.
+ sizeOfStringInt32 = 10
+)
-// noCrashOnVerificationFailure indicates whether the sandbox should panic
-// whenever verification fails. If true, an error is returned instead of
-// panicking. This should only be set for tests.
-// TOOD(b/165661693): Decide whether to panic or return error based on this
-// flag.
-var noCrashOnVerificationFailure bool
+var (
+ // noCrashOnVerificationFailure indicates whether the sandbox should panic
+ // whenever verification fails. If true, an error is returned instead of
+ // panicking. This should only be set for tests.
+ //
+ // TODO(b/165661693): Decide whether to panic or return error based on this
+ // flag.
+ noCrashOnVerificationFailure bool
+
+ // verityMu synchronizes concurrent operations that enable verity and perform
+ // verification checks.
+ verityMu sync.RWMutex
+)
+
+// HashAlgorithm is a type specifying the algorithm used to hash the file
+// content.
+type HashAlgorithm int
+
+// Currently supported hashing algorithms include SHA256 and SHA512.
+const (
+ SHA256 HashAlgorithm = iota
+ SHA512
+)
+
+func (alg HashAlgorithm) toLinuxHashAlg() int {
+ switch alg {
+ case SHA256:
+ return linux.FS_VERITY_HASH_ALG_SHA256
+ case SHA512:
+ return linux.FS_VERITY_HASH_ALG_SHA512
+ default:
+ return 0
+ }
+}
// FilesystemType implements vfs.FilesystemType.
//
@@ -97,6 +129,10 @@ type filesystem struct {
// stores the root hash of the whole file system in bytes.
rootDentry *dentry
+ // alg is the algorithms used to hash the files in the verity file
+ // system.
+ alg HashAlgorithm
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -125,6 +161,10 @@ type InternalFilesystemOptions struct {
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
+ // Alg is the algorithms used to hash the files in the verity file
+ // system.
+ Alg HashAlgorithm
+
// RootHash is the root hash of the overall verity file system.
RootHash []byte
@@ -153,10 +193,10 @@ func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
-func alertIntegrityViolation(err error, msg string) error {
+// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+func alertIntegrityViolation(msg string) error {
if noCrashOnVerificationFailure {
- return err
+ return syserror.EIO
}
panic(msg)
}
@@ -183,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fs := &filesystem{
creds: creds.Fork(),
+ alg: iopts.Alg,
lowerMount: mnt,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
@@ -236,7 +277,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// the root Merkle file, or it's never generated.
fs.vfsfs.DecRef(ctx)
d.DecRef(ctx)
- return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file")
+ return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
}
d.lowerMerkleVD = lowerMerkleVD
@@ -289,11 +330,12 @@ type dentry struct {
// fs is the owning filesystem. fs is immutable.
fs *filesystem
- // mode, uid and gid are the file mode, owner, and group of the file in
- // the underlying file system.
+ // mode, uid, gid and size are the file mode, owner, group, and size of
+ // the file in the underlying file system.
mode uint32
uid uint32
gid uint32
+ size uint32
// parent is the dentry corresponding to this dentry's parent directory.
// name is this dentry's name in parent. If this dentry is a filesystem
@@ -331,22 +373,25 @@ func (fs *filesystem) newDentry() *dentry {
fs: fs,
}
d.vfsd.Init(d)
+ refsvfs2.Register(d)
return d
}
// IncRef implements vfs.DentryImpl.IncRef.
func (d *dentry) IncRef() {
- atomic.AddInt64(&d.refs, 1)
+ r := atomic.AddInt64(&d.refs, 1)
+ refsvfs2.LogIncRef(d, r)
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
func (d *dentry) TryIncRef() bool {
for {
- refs := atomic.LoadInt64(&d.refs)
- if refs <= 0 {
+ r := atomic.LoadInt64(&d.refs)
+ if r <= 0 {
return false
}
- if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&d.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(d, r+1)
return true
}
}
@@ -354,15 +399,27 @@ func (d *dentry) TryIncRef() bool {
// DecRef implements vfs.DentryImpl.DecRef.
func (d *dentry) DecRef(ctx context.Context) {
- if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
d.fs.renameMu.Lock()
d.checkDropLocked(ctx)
d.fs.renameMu.Unlock()
- } else if refs < 0 {
+ } else if r < 0 {
panic("verity.dentry.DecRef() called without holding a reference")
}
}
+func (d *dentry) decRefLocked(ctx context.Context) {
+ r := atomic.AddInt64(&d.refs, -1)
+ refsvfs2.LogDecRef(d, r)
+ if r == 0 {
+ d.checkDropLocked(ctx)
+ } else if r < 0 {
+ panic("verity.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
// checkDropLocked should be called after d's reference count becomes 0 or it
// becomes deleted.
func (d *dentry) checkDropLocked(ctx context.Context) {
@@ -393,23 +450,36 @@ func (d *dentry) destroyLocked(ctx context.Context) {
if d.lowerVD.Ok() {
d.lowerVD.DecRef(ctx)
}
-
if d.lowerMerkleVD.Ok() {
d.lowerMerkleVD.DecRef(ctx)
}
-
if d.parent != nil {
d.parent.dirMu.Lock()
if !d.vfsd.IsDead() {
delete(d.parent.children, d.name)
}
d.parent.dirMu.Unlock()
- if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
- d.parent.checkDropLocked(ctx)
- } else if refs < 0 {
- panic("verity.dentry.DecRef() called without holding a reference")
- }
+ d.parent.decRefLocked(ctx)
}
+ refsvfs2.Unregister(d)
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (d *dentry) RefType() string {
+ return "verity.dentry"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (d *dentry) LogRefs() bool {
+ return false
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
@@ -448,6 +518,16 @@ func (d *dentry) verityEnabled() bool {
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
+// getLowerAt returns the dentry in the underlying file system, which is
+// represented by filename relative to d.
+func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) {
+ return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.GetDentryOptions{})
+}
+
func (d *dentry) readlink(ctx context.Context) (string, error) {
return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
Root: d.lowerVD,
@@ -489,6 +569,10 @@ type fileDescription struct {
// directory that contains the current file/directory. This is only used
// if allowRuntimeEnable is set to true.
parentMerkleWriter *vfs.FileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex `state:"nosave"`
+ off int64
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -524,6 +608,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ n := int64(0)
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ n = fd.off
+ case linux.SEEK_END:
+ n = int64(fd.d.size)
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset > math.MaxInt64-n {
+ return 0, syserror.EINVAL
+ }
+ offset += n
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
// generateMerkle generates a Merkle tree file for fd. If fd points to a file
// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
// of the generated Merkle tree and the data size is returned. If fd points to
@@ -546,6 +656,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
params := &merkletree.GenerateParams{
TreeReader: &merkleReader,
TreeWriter: &merkleWriter,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
}
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
@@ -611,7 +723,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// or directory other than the root, the parent Merkle tree file should
// have also been initialized.
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
- return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
+ return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
hash, dataSize, err := fd.generateMerkle(ctx)
@@ -657,6 +769,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// measureVerity returns the hash of fd, saved in verityDigest.
func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
var metadata linux.DigestMetadata
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
@@ -667,7 +782,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve
if fd.d.fs.allowRuntimeEnable {
return 0, syserror.ENODATA
}
- return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found")
+ return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found")
}
// The first part of VerityDigest is the metadata.
@@ -702,6 +817,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag
}
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
_, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -722,6 +840,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
}
}
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Implement Read with PRead by setting offset.
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
@@ -742,7 +870,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
return 0, err
@@ -752,7 +880,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
dataReader := vfs.FileReadWriteSeeker{
@@ -766,25 +894,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
}
n, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: dst.Writer(ctx),
- File: &dataReader,
- Tree: &merkleReader,
- Size: int64(size),
- Name: fd.d.name,
- Mode: fd.d.mode,
- UID: fd.d.uid,
- GID: fd.d.gid,
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
ReadOffset: offset,
ReadSize: dst.NumBytes(),
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
if err != nil {
- return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index e301d35f5..b2da9dd96 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -25,10 +25,12 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -41,11 +43,18 @@ const maxDataSize = 100000
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
-func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("testutil.Boot: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+
rand.Seed(time.Now().UnixNano())
vfsObj := &vfs.VirtualFilesystem{}
if err := vfsObj.Init(ctx); err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -61,22 +70,33 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v
InternalData: InternalFilesystemOptions{
RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
+ Alg: hashAlg,
AllowRuntimeEnable: true,
NoCrashOnVerificationFailure: true,
},
},
})
if err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err)
}
root := mntns.Root()
root.IncRef()
+
+ // Use lowerRoot in the task as we modify the lower file system
+ // directly in many tests.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot)
+ if err != nil {
+ t.Fatalf("testutil.CreateTask: %v", err)
+ }
+
t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
})
- return vfsObj, root, nil
+ return vfsObj, root, task, nil
}
// newFileFD creates a new file in the verity mount, and returns the FD. The FD
@@ -142,207 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
func TestOpen(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Ensure that the corresponding Merkle tree file is created.
- lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Ensure that the corresponding Merkle tree file is created.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
+ }
+
+ // Ensure the root merkle tree file is created.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerRoot,
+ Start: lowerRoot,
+ Path: fspath.Parse(merklePrefix + rootMerkleFilename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }); err != nil {
+ t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+ }
}
+}
- // Ensure the root merkle tree file is created.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerRoot,
- Start: lowerRoot,
- Path: fspath.Parse(merklePrefix + rootMerkleFilename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- }); err != nil {
- t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
+// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity
+// file succeeds after enabling verity for it.
+func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
}
}
-// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file
-// succeeds after enabling verity for it.
+// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity
+// file succeeds after enabling verity for it.
func TestReadUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- buf := make([]byte, size)
- n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
- if err != nil && err != io.EOF {
- t.Fatalf("fd.PRead: %v", err)
- }
-
- if n != int64(size) {
- t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
}
}
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Ensure reopening the verity enabled file succeeds.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err != nil {
- t.Errorf("reopen enabled file failed: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Ensure reopening the verity enabled file succeeds.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != nil {
+ t.Errorf("reopen enabled file failed: %v", err)
+ }
}
}
-// TestModifiedFileFails ensures that read from a modified verity file fails.
-func TestModifiedFileFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerFD that's read/writable.
- lowerVD := fd.Impl().(*fileDescription).d.lowerVD
-
- lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerVD,
- Start: lowerVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+// TestPReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestPReadModifiedFileFails(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.PRead succeeded, expected failure")
+ }
}
+}
- // Confirm that read from the modified file fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- t.Fatalf("fd.PRead succeeded with modified file")
+// TestReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestReadModifiedFileFails(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.Read succeeded, expected failure")
+ }
}
}
// TestModifiedMerkleFails ensures that read from a verity file fails if the
// corresponding Merkle tree file is modified.
func TestModifiedMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
-
- lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: lowerMerkleVD,
- Start: lowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the Merkle tree file.
- stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- merkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- // Confirm that read from a file with modified Merkle tree fails.
- buf := make([]byte, size)
- if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
- t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD
+
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the Merkle tree file.
+ stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ merkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from a file with modified Merkle tree fails.
+ buf := make([]byte, size)
+ if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
+ fmt.Println(buf)
+ t.Fatalf("fd.PRead succeeded with modified Merkle file")
+ }
}
}
@@ -350,142 +459,267 @@ func TestModifiedMerkleFails(t *testing.T) {
// verity enabled directory fails if the hashes related to the target file in
// the parent Merkle tree file is modified.
func TestModifiedParentMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Enable verity on the parent directory.
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
-
- // Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
-
- parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: parentLowerMerkleVD,
- Start: parentLowerMerkleVD,
- }, &vfs.OpenOptions{
- Flags: linux.O_RDWR,
- })
- if err != nil {
- t.Fatalf("OpenAt: %v", err)
- }
-
- // Flip a random bit in the parent Merkle tree file.
- // This parent directory contains only one child, so any random
- // modification in the parent Merkle tree should cause verification
- // failure when opening the child file.
- stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
- if err != nil {
- t.Fatalf("stat: %v", err)
- }
- parentMerkleSize := int(stat.Size)
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
- }
-
- parentLowerMerkleFD.DecRef(ctx)
-
- // Ensure reopening the verity enabled file fails.
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- }, &vfs.OpenOptions{
- Flags: linux.O_RDONLY,
- Mode: linux.ModeRegular,
- }); err == nil {
- t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Enable verity on the parent directory.
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerMerkleFD that's read/writable.
+ parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD
+
+ parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: parentLowerMerkleVD,
+ Start: parentLowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ // Flip a random bit in the parent Merkle tree file.
+ // This parent directory contains only one child, so any random
+ // modification in the parent Merkle tree should cause verification
+ // failure when opening the child file.
+ stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("stat: %v", err)
+ }
+ parentMerkleSize := int(stat.Size)
+ if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ parentLowerMerkleFD.DecRef(ctx)
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err == nil {
+ t.Errorf("OpenAt file with modified parent Merkle succeeded")
+ }
}
}
// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
// succeeds after enabling verity for it.
func TestUnmodifiedStatSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
-
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
- t.Errorf("fd.Stat: %v", err)
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirms stat succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
+ t.Errorf("fd.Stat: %v", err)
+ }
}
}
// TestModifiedStatFails checks that getting stat for a file with modified stat
// should fail.
func TestModifiedStatFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
- if err != nil {
- t.Fatalf("newVerityRoot: %v", err)
- }
-
- filename := "verity-test-file"
- fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
- if err != nil {
- t.Fatalf("newFileFD: %v", err)
- }
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("fd.Ioctl: %v", err)
+ }
+
+ lowerFD := fd.Impl().(*fileDescription).lowerFD
+ // Change the stat of the underlying file, and check that stat fails.
+ if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: uint32(linux.STATX_MODE),
+ Mode: 0777,
+ },
+ }); err != nil {
+ t.Fatalf("lowerFD.SetStat: %v", err)
+ }
- // Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
+ if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
+ t.Errorf("fd.Stat succeeded when it should fail")
+ }
}
+}
- lowerFD := fd.Impl().(*fileDescription).lowerFD
- // Change the stat of the underlying file, and check that stat fails.
- if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{
- Stat: linux.Statx{
- Mask: uint32(linux.STATX_MODE),
- Mode: 0777,
+// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed
+// verity enabled file or the corresponding Merkle tree file fails with the
+// verify error.
+func TestOpenDeletedFileFails(t *testing.T) {
+ testCases := []struct {
+ // Tests removing files is remove is true. Otherwise tests
+ // renaming files.
+ remove bool
+ // The original file is removed/renamed if changeFile is true.
+ changeFile bool
+ // The Merkle tree file is removed/renamed if changeMerkleFile
+ // is true.
+ changeMerkleFile bool
+ }{
+ {
+ remove: true,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: true,
+ changeFile: false,
+ changeMerkleFile: true,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
},
- }); err != nil {
- t.Fatalf("lowerFD.SetStat: %v", err)
}
-
- if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil {
- t.Errorf("fd.Stat succeeded when it should fail")
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) {
+ for _, alg := range hashAlgs {
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
+ if tc.remove {
+ if tc.changeFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ } else {
+ newFilename := "renamed-test-file"
+ if tc.changeFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("RenameAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ }
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != syserror.EIO {
+ t.Errorf("got OpenAt error: %v, expected EIO", err)
+ }
+ }
+ })
}
}
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
index 364a78306..db3b0d0a0 100644
--- a/pkg/sentry/hostfd/BUILD
+++ b/pkg/sentry/hostfd/BUILD
@@ -6,10 +6,12 @@ go_library(
name = "hostfd",
srcs = [
"hostfd.go",
+ "hostfd_linux.go",
"hostfd_unsafe.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/log",
"//pkg/safemem",
"//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go
new file mode 100644
index 000000000..1cabc848f
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_linux.go
@@ -0,0 +1,18 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+// maxIov is the maximum permitted size of a struct iovec array.
+const maxIov = 1024 // UIO_MAXIOV
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
index cd4dc67fb..694371b1c 100644
--- a/pkg/sentry/hostfd/hostfd_unsafe.go
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -20,6 +20,7 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
)
@@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6
}
} else {
iovs := safemem.IovecsFromBlockSeq(dsts)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
@@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint
}
} else {
iovs := safemem.IovecsFromBlockSeq(srcs)
+ if len(iovs) > maxIov {
+ log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov)
+ iovs = iovs[:maxIov]
+ }
n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
}
if e != 0 {
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index fbe6d6aa6..f31277d30 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -32,9 +32,13 @@ type Stack interface {
InterfaceAddrs() map[int32][]InterfaceAddr
// AddInterfaceAddr adds an address to the network interface identified by
- // index.
+ // idx.
AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+ // RemoveInterfaceAddr removes an address from the network interface
+ // identified by idx.
+ RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error
+
// SupportsIPv6 returns true if the stack supports IPv6 connectivity.
SupportsIPv6() bool
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 1779cc6f3..9ebeba8a3 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -15,6 +15,9 @@
package inet
import (
+ "bytes"
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
return nil
}
+// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr.
+func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ interfaceAddrs, ok := s.InterfaceAddrsMap[idx]
+ if !ok {
+ return fmt.Errorf("unknown idx: %d", idx)
+ }
+
+ var filteredAddrs []InterfaceAddr
+ for _, interfaceAddr := range interfaceAddrs {
+ if !bytes.Equal(interfaceAddr.Addr, addr.Addr) {
+ filteredAddrs = append(filteredAddrs, addr)
+ }
+ }
+ s.InterfaceAddrsMap[idx] = filteredAddrs
+
+ return nil
+}
+
// SupportsIPv6 implements Stack.SupportsIPv6.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c0de72eef..90dd4a047 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -79,7 +79,7 @@ go_template_instance(
out = "fd_table_refs.go",
package = "kernel",
prefix = "FDTable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FDTable",
},
@@ -90,7 +90,7 @@ go_template_instance(
out = "fs_context_refs.go",
package = "kernel",
prefix = "FSContext",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FSContext",
},
@@ -101,7 +101,7 @@ go_template_instance(
out = "ipc_namespace_refs.go",
package = "kernel",
prefix = "IPCNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "IPCNamespace",
},
@@ -112,7 +112,7 @@ go_template_instance(
out = "process_group_refs.go",
package = "kernel",
prefix = "ProcessGroup",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ProcessGroup",
},
@@ -123,7 +123,7 @@ go_template_instance(
out = "session_refs.go",
package = "kernel",
prefix = "Session",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Session",
},
@@ -229,7 +229,7 @@ go_library(
"//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/secio",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 1b9721534..0ddbe5ff6 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -19,7 +19,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs_vfs2"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -27,7 +27,7 @@ import (
// +stateify savable
type abstractEndpoint struct {
ep transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
name string
ns *AbstractSocketNamespace
}
@@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace {
// its backing socket.
type boundEndpoint struct {
transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
}
// Release implements transport.BoundEndpoint.Release.
@@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
//
// When the last reference managed by socket is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
@@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran
// Remove removes the specified socket at name from the abstract socket
// namespace, if it has not yet been replaced.
-func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) {
+func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) {
a.mu.Lock()
defer a.mu.Unlock()
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 0ec7344cd..7aba31587 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
- f.init() // Initialize table.
+ f.initNoLeakCheck() // Initialize table.
f.used = 0
for fd, d := range m {
if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
@@ -240,6 +240,10 @@ func (f *FDTable) String() string {
case fileVFS2 != nil:
vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
+ vd := fileVFS2.VirtualDentry()
+ if vd.Dentry() == nil {
+ panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2))
+ }
name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
if err != nil {
fmt.Fprintf(&buf, "<err: %v>\n", err)
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index da79e6627..3476551f3 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -31,14 +31,21 @@ type descriptorTable struct {
slice unsafe.Pointer `state:".(map[int32]*descriptor)"`
}
-// init initializes the table.
+// initNoLeakCheck initializes the table without enabling leak checking.
//
-// TODO(gvisor.dev/1486): Enable leak check for FDTable.
-func (f *FDTable) init() {
+// This is used when loading an FDTable after S/R, during which the ref count
+// object itself will enable leak checking if necessary.
+func (f *FDTable) initNoLeakCheck() {
var slice []unsafe.Pointer // Empty slice.
atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
}
+// init initializes the table with leak checking.
+func (f *FDTable) init() {
+ f.initNoLeakCheck()
+ f.EnableLeakCheck()
+}
+
// get gets a file entry.
//
// The boolean indicates whether this was in range.
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index d46d1e1c1..41fb2a784 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext {
f.root.IncRef()
}
- return &FSContext{
+ ctx := &FSContext{
cwd: f.cwd,
root: f.root,
cwdVFS2: f.cwdVFS2,
rootVFS2: f.rootVFS2,
umask: f.umask,
}
+ ctx.EnableLeakCheck()
+ return ctx
}
// WorkingDirectory returns the current working directory.
@@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
+ if f.cwd != nil {
+ f.cwd.IncRef()
+ }
return f.cwd
}
// WorkingDirectoryVFS2 returns the current working directory.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwdVFS2.IncRef()
+ if f.cwdVFS2.Ok() {
+ f.cwdVFS2.IncRef()
+ }
return f.cwdVFS2
}
@@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
// RootDirectoryVFS2 returns the current filesystem root.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.rootVFS2.IncRef()
+ if f.rootVFS2.Ok() {
+ f.rootVFS2.IncRef()
+ }
return f.rootVFS2
}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
index 3f34ee0db..b87e40dd1 100644
--- a/pkg/sentry/kernel/ipc_namespace.go
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry {
return i.shms
}
-// DecRef implements refs_vfs2.RefCounter.DecRef.
+// DecRef implements refsvfs2.RefCounter.DecRef.
func (i *IPCNamespace) DecRef(ctx context.Context) {
i.IPCNamespaceRefs.DecRef(func() {
i.shms.Release(ctx)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 0eb2bf7bd..9b2be44d4 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -430,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w wire.Writer) error {
+func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
saveStart := time.Now()
- ctx := k.SupervisorContext()
// Do not allow other Kernel methods to affect it while it's being saved.
k.extMu.Lock()
@@ -446,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
k.mf.StartEvictions()
k.mf.WaitForEvictions()
- // Flush write operations on open files so data reaches backing storage.
- // This must come after MemoryFile eviction since eviction may cause file
- // writes.
- if err := k.tasks.flushWritesToFiles(ctx); err != nil {
- return err
- }
+ if VFS2Enabled {
+ // Discard unsavable mappings, such as those for host file descriptors.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
- // Remove all epoll waiter objects from underlying wait queues.
- // NOTE: for programs to resume execution in future snapshot scenarios,
- // we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters(ctx)
+ // Prepare filesystems for saving. This must be done after
+ // invalidateUnsavableMappings(), since dropping memory mappings may
+ // affect filesystem state (e.g. page cache reference counts).
+ if err := k.vfs.PrepareSave(ctx); err != nil {
+ return err
+ }
+ } else {
+ // Flush cached file writes to backing storage. This must come after
+ // MemoryFile eviction since eviction may cause file writes.
+ if err := k.flushWritesToFiles(ctx); err != nil {
+ return err
+ }
- // Clear the dirent cache before saving because Dirents must be Loaded in a
- // particular order (parents before children), and Loading dirents from a cache
- // breaks that order.
- if err := k.flushMountSourceRefs(ctx); err != nil {
- return err
- }
+ // Remove all epoll waiter objects from underlying wait queues.
+ // NOTE: for programs to resume execution in future snapshot scenarios,
+ // we will need to re-establish these waiter objects after saving.
+ k.tasks.unregisterEpollWaiters(ctx)
- // Ensure that all inode and mount release operations have completed.
- fs.AsyncBarrier()
+ // Clear the dirent cache before saving because Dirents must be Loaded in a
+ // particular order (parents before children), and Loading dirents from a cache
+ // breaks that order.
+ if err := k.flushMountSourceRefs(ctx); err != nil {
+ return err
+ }
- // Once all fs work has completed (flushed references have all been released),
- // reset mount mappings. This allows individual mounts to save how inodes map
- // to filesystem resources. Without this, fs.Inodes cannot be restored.
- fs.SaveInodeMappings()
+ // Ensure that all inode and mount release operations have completed.
+ fs.AsyncBarrier()
- // Discard unsavable mappings, such as those for host file descriptors.
- // This must be done after waiting for "asynchronous fs work", which
- // includes async I/O that may touch application memory.
- if err := k.invalidateUnsavableMappings(ctx); err != nil {
- return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ // Once all fs work has completed (flushed references have all been released),
+ // reset mount mappings. This allows individual mounts to save how inodes map
+ // to filesystem resources. Without this, fs.Inodes cannot be restored.
+ fs.SaveInodeMappings()
+
+ // Discard unsavable mappings, such as those for host file descriptors.
+ // This must be done after waiting for "asynchronous fs work", which
+ // includes async I/O that may touch application memory.
+ //
+ // TODO(gvisor.dev/issue/1624): This rationale is believed to be
+ // obsolete since AIO callbacks are now waited-for by Kernel.Pause(),
+ // but this order is conservatively retained for VFS1.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
}
// Save the CPUID FeatureSet before the rest of the kernel so we can
@@ -486,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
+ if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- stats, err := state.Save(k.SupervisorContext(), w, k)
+ stats, err := state.Save(ctx, w, k)
if err != nil {
return err
}
@@ -502,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
+ if err := k.mf.SaveTo(ctx, w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -514,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
+//
+// Preconditions: !VFS2Enabled.
func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
- if VFS2Enabled {
- return nil // Not relevant.
- }
-
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -561,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi
return err
}
-func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return nil
- }
-
- return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
+// Preconditions: !VFS2Enabled.
+func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
+ return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
}
@@ -589,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
})
}
-// Preconditions: The kernel must be paused.
-func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
- invalidated := make(map[*mm.MemoryManager]struct{})
- k.tasks.mu.RLock()
- defer k.tasks.mu.RUnlock()
- for t := range k.tasks.Root.tids {
- // We can skip locking Task.mu here since the kernel is paused.
- if mm := t.tc.MemoryManager; mm != nil {
- if _, ok := invalidated[mm]; !ok {
- if err := mm.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- invalidated[mm] = struct{}{}
- }
- }
- // I really wish we just had a sync.Map of all MMs...
- if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
- if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
+// Preconditions: !VFS2Enabled.
func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return
- }
-
ts.mu.RLock()
defer ts.mu.RUnlock()
@@ -644,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
}
}
+// Preconditions: The kernel must be paused.
+func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
+ invalidated := make(map[*mm.MemoryManager]struct{})
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t := range k.tasks.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if mm := t.tc.MemoryManager; mm != nil {
+ if _, ok := invalidated[mm]; !ok {
+ if err := mm.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ invalidated[mm] = struct{}{}
+ }
+ }
+ // I really wish we just had a sync.Map of all MMs...
+ if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
+ if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -656,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
+ if _, err := state.Load(ctx, r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -671,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the kernel state.
kernelStart := time.Now()
- stats, err := state.Load(k.SupervisorContext(), r, k)
+ stats, err := state.Load(ctx, r, k)
if err != nil {
return err
}
@@ -684,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
+ if err := k.mf.LoadFrom(ctx, r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -696,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
net.Resume()
}
- // Ensure that all pending asynchronous work is complete:
- // - namedpipe opening
- // - inode file opening
- if err := fs.AsyncErrorBarrier(); err != nil {
- return err
+ if VFS2Enabled {
+ if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil {
+ return err
+ }
+ } else {
+ // Ensure that all pending asynchronous work is complete:
+ // - namedpipe opening
+ // - inode file opening
+ if err := fs.AsyncErrorBarrier(); err != nil {
+ return err
+ }
}
tcpip.AsyncLoading.Wait()
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index ce0db5583..d6fb0fdb8 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
)
type sleeper struct {
@@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag
d := fs.NewDirent(ctx, inode, "pipe")
file, err := n.GetFile(ctx, d, flags)
if err != nil {
- t.Fatalf("open with flags %+v failed: %v", flags, err)
+ t.Errorf("open with flags %+v failed: %v", flags, err)
+ return nil, err
}
if doneChan != nil {
doneChan <- struct{}{}
@@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs.
}
func newNamedPipe(t *testing.T) *Pipe {
- return NewPipe(true, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(true, DefaultPipeSize)
}
func newAnonPipe(t *testing.T) *Pipe {
- return NewPipe(false, DefaultPipeSize, usermem.PageSize)
+ return NewPipe(false, DefaultPipeSize)
}
// assertRecvBlocks ensures that a recv attempt on c blocks for at least
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 67beb0ad6..b989e14c7 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -26,18 +26,27 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
// MinimumPipeSize is a hard limit of the minimum size of a pipe.
- MinimumPipeSize = 64 << 10
+ // It corresponds to fs/pipe.c:pipe_min_size.
+ MinimumPipeSize = usermem.PageSize
+
+ // MaximumPipeSize is a hard limit on the maximum size of a pipe.
+ // It corresponds to fs/pipe.c:pipe_max_size.
+ MaximumPipeSize = 1048576
// DefaultPipeSize is the system-wide default size of a pipe in bytes.
- DefaultPipeSize = MinimumPipeSize
+ // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS.
+ DefaultPipeSize = 16 * usermem.PageSize
- // MaximumPipeSize is a hard limit on the maximum size of a pipe.
- MaximumPipeSize = 8 << 20
+ // atomicIOBytes is the maximum number of bytes that the pipe will
+ // guarantee atomic reads or writes atomically.
+ // It corresponds to limits.h:PIPE_BUF.
+ atomicIOBytes = 4096
)
// Pipe is an encapsulation of a platform-independent pipe.
@@ -53,12 +62,6 @@ type Pipe struct {
// This value is immutable.
isNamed bool
- // atomicIOBytes is the maximum number of bytes that the pipe will
- // guarantee atomic reads or writes atomically.
- //
- // This value is immutable.
- atomicIOBytes int64
-
// The number of active readers for this pipe.
//
// Access atomically.
@@ -94,47 +97,34 @@ type Pipe struct {
// NewPipe initializes and returns a pipe.
//
-// N.B. The size and atomicIOBytes will be bounded.
-func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe {
+// N.B. The size will be bounded.
+func NewPipe(isNamed bool, sizeBytes int64) *Pipe {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
var p Pipe
- initPipe(&p, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&p, isNamed, sizeBytes)
return &p
}
-func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) {
+func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) {
if sizeBytes < MinimumPipeSize {
sizeBytes = MinimumPipeSize
}
if sizeBytes > MaximumPipeSize {
sizeBytes = MaximumPipeSize
}
- if atomicIOBytes <= 0 {
- atomicIOBytes = 1
- }
- if atomicIOBytes > sizeBytes {
- atomicIOBytes = sizeBytes
- }
pipe.isNamed = isNamed
pipe.max = sizeBytes
- pipe.atomicIOBytes = atomicIOBytes
}
// NewConnectedPipe initializes a pipe and returns a pair of objects
// representing the read and write ends of the pipe.
-func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) {
- p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes)
+func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) {
+ p := NewPipe(false /* isNamed */, sizeBytes)
// Build an fs.Dirent for the pipe which will be shared by both
// returned files.
@@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
wanted := ops.left()
avail := p.max - p.view.Size()
if wanted > avail {
- if wanted <= p.atomicIOBytes {
+ if wanted <= atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index fe97e9800..3dd739080 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -26,7 +26,7 @@ import (
func TestPipeRW(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) {
func TestPipeReadBlock(t *testing.T) {
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, 65536, 4096)
+ r, w := NewConnectedPipe(ctx, 65536)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) {
const capacity = MinimumPipeSize
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, capacity)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) {
const atomicIOBytes = 2
ctx := contexttest.Context(t)
- r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
+ r, w := NewConnectedPipe(ctx, atomicIOBytes)
defer r.DecRef(ctx)
defer w.DecRef(ctx)
@@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) {
}
}
if err != nil {
- t.Fatalf("Readv: got unexpected error %v", err)
+ t.Errorf("Readv: got unexpected error %v", err)
+ return
}
}
}()
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 1a152142b..7b23cbe86 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -33,6 +33,8 @@ import (
// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should
// not be copied.
+//
+// +stateify savable
type VFSPipe struct {
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
@@ -52,9 +54,9 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes)
return &vp
}
@@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l
// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
// other FileDescriptions for splice(2) and tee(2).
+//
+// +stateify savable
type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1145faf13..1abfe2201 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
// at the address specified by the data parameter, and the return value
// is the error flag." - ptrace(2)
word := t.Arch().Native(0)
- if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
+ if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil {
return err
}
_, err := word.CopyOut(t, data)
@@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA:
word := t.Arch().Native(uintptr(data))
- _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr)
+ _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr)
return err
case linux.PTRACE_GETREGSET:
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index c00fa1138..b99c0bffa 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -103,6 +103,7 @@ type waiter struct {
waiterEntry
// value represents how much resource the waiter needs to wake up.
+ // The value is either 0 or negative.
value int16
ch chan struct{}
}
@@ -283,6 +284,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File
return nil
}
+// GetStat extracts semid_ds information from the set.
+func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ ds := &linux.SemidDS{
+ SemPerm: linux.IPCPerm{
+ Key: uint32(s.key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
+ Mode: uint16(s.perms.LinuxMode()),
+ Seq: 0, // IPC sequence not supported.
+ },
+ SemOTime: s.opTime.TimeT(),
+ SemCTime: s.changeTime.TimeT(),
+ SemNSems: uint64(s.Size()),
+ }
+ return ds, nil
+}
+
// SetVal overrides a semaphore value, waking up waiters as needed.
func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error {
if val < 0 || val > valueMax {
@@ -320,7 +348,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
}
for _, val := range vals {
- if val < 0 || val > valueMax {
+ if val > valueMax {
return syserror.ERANGE
}
}
@@ -396,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) {
return sem.pid, nil
}
+func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // The calling process must have read permission on the semaphore set.
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return 0, syserror.EACCES
+ }
+
+ sem := s.findSem(num)
+ if sem == nil {
+ return 0, syserror.ERANGE
+ }
+ var cnt uint16
+ for w := sem.waiters.Front(); w != nil; w = w.Next() {
+ if pred(w) {
+ cnt++
+ }
+ }
+ return cnt, nil
+}
+
+// CountZeroWaiters returns number of waiters waiting for the sem's value to increase.
+func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value == 0
+ })
+}
+
+// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero.
+func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) {
+ return s.countWaiters(num, creds, func(w *waiter) bool {
+ return w.value < 0
+ })
+}
+
// ExecuteOps attempts to execute a list of operations to the set. It only
// succeeds when all operations can be applied. No changes are made if it fails.
//
@@ -548,11 +612,18 @@ func (s *Set) destroy() {
}
}
+func abs(val int16) int16 {
+ if val < 0 {
+ return -val
+ }
+ return val
+}
+
// wakeWaiters goes over all waiters and checks which of them can be notified.
func (s *sem) wakeWaiters() {
// Note that this will release all waiters waiting for 0 too.
for w := s.waiters.Front(); w != nil; {
- if s.value < w.value {
+ if s.value < abs(w.value) {
// Still blocked, skip it.
w = w.Next()
continue
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index df5c8421b..5bddb0a36 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session {
//
// If this group isn't visible in this namespace, zero will be returned. It is
// the callers responsibility to check that before using this function.
-func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sids[s]
+func (ns *PIDNamespace) IDOfSession(s *Session) SessionID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sids[s]
}
// SessionWithID returns the Session with the given ID in the PID namespace ns,
// or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the session.
-func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.sessions[id]
+func (ns *PIDNamespace) SessionWithID(id SessionID) *Session {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.sessions[id]
}
// ProcessGroup returns the ThreadGroup's ProcessGroup.
@@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup {
// IDOfProcessGroup returns the process group assigned to pg in PID namespace ns.
//
// The same constraints apply as IDOfSession.
-func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.pgids[pg]
+func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.pgids[pg]
}
// ProcessGroupWithID returns the ProcessGroup with the given ID in the PID
// namespace ns, or nil if that given ID is not defined in this namespace.
//
// A reference is not taken on the process group.
-func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
- pidns.owner.mu.RLock()
- defer pidns.owner.mu.RUnlock()
- return pidns.processGroups[id]
+func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup {
+ ns.owner.mu.RLock()
+ defer ns.owner.mu.RUnlock()
+ return ns.processGroups[id]
}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index f8a382fd8..80a592c8f 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "shm_refs.go",
package = "shm",
prefix = "Shm",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Shm",
},
@@ -27,7 +27,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 682080c14..527344162 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
}
if opts.ChildSetTID {
ctid := nt.ThreadID()
- ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
+ ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID)
}
ntid := t.tg.pidns.IDOfTask(nt)
if opts.ParentSetTID {
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index ce134bf54..94dabbcd8 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,7 +18,8 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp
}, nil
}
-// copyContext implements marshal.CopyContext. It wraps a task to allow copying
-// memory to and from the task memory with custom usermem.IOOpts.
-type copyContext struct {
- *Task
+type taskCopyContext struct {
+ ctx context.Context
+ t *Task
opts usermem.IOOpts
}
-// AsCopyContext wraps the task and returns it as CopyContext.
-func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext {
- return &copyContext{t, opts}
+// CopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts.
+func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext {
+ return &taskCopyContext{
+ ctx: ctx,
+ t: t,
+ opts: opts,
+ }
+}
+
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte {
+ if ctxTask, ok := cc.ctx.(*Task); ok {
+ return ctxTask.CopyScratchBuffer(size)
+ }
+ return make([]byte, size)
+}
+
+func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) {
+ cc.t.mu.Lock()
+ tmm := cc.t.MemoryManager()
+ cc.t.mu.Unlock()
+ if !tmm.IncUsers() {
+ return nil, syserror.EFAULT
+ }
+ return tmm, nil
+}
+
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyIn(cc.ctx, addr, dst, cc.opts)
+}
+
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ tmm, err := cc.getMemoryManager()
+ if err != nil {
+ return 0, err
+ }
+ defer tmm.DecUsers(cc.ctx)
+ return tmm.CopyOut(cc.ctx, addr, src, cc.opts)
+}
+
+type ownTaskCopyContext struct {
+ t *Task
+ opts usermem.IOOpts
+}
+
+// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address
+// space using opts. The returned CopyContext may only be used by t's task
+// goroutine.
+//
+// Since t already implements marshal.CopyContext, this is only needed to
+// override the usermem.IOOpts used for the copy.
+func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext {
+ return &ownTaskCopyContext{
+ t: t,
+ opts: opts,
+ }
}
-// CopyInString copies a string in from the task's memory.
-func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) {
- return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts)
+// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer.
+func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte {
+ return cc.t.CopyScratchBuffer(size)
}
-// CopyInBytes copies task memory into dst from an IO context.
-func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
- return t.MemoryManager().CopyIn(t, addr, dst, t.opts)
+// CopyInBytes implements marshal.CopyContext.CopyInBytes.
+func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) {
+ return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts)
}
-// CopyOutBytes copies src into task memoryfrom an IO context.
-func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
- return t.MemoryManager().CopyOut(t, addr, src, t.opts)
+// CopyOutBytes implements marshal.CopyContext.CopyOutBytes.
+func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) {
+ return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts)
}
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index 9bc452e67..9e5c2d26f 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error {
}
if old != v.seq {
- return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq)
+ return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq)
}
v.seq = next
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b4a47ccca..6dbeccfe2 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -78,7 +78,7 @@ go_template_instance(
out = "aio_mappable_refs.go",
package = "mm",
prefix = "aioMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "aioMappable",
},
@@ -89,7 +89,7 @@ go_template_instance(
out = "special_mappable_refs.go",
package = "mm",
prefix = "SpecialMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SpecialMappable",
},
@@ -127,6 +127,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safecopy",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 0a54dd30d..acad4c793 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) {
c.runData.requestInterruptWindow = 0
}
+// bluepillSigBus is reponsible for injecting NMI to trigger sigbus.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_NMI, 0); errno != 0 {
+ throw("NMI injection failed")
+ }
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 58f3d6fdd..965ad66b5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -27,15 +27,20 @@ var (
// The action for bluepillSignal is changed by sigaction().
bluepillSignal = syscall.SIGILL
- // vcpuSErr is the event of system error.
- vcpuSErr = kvmVcpuEvents{
+ // vcpuSErrBounce is the event of system error for bouncing KVM.
+ vcpuSErrBounce = kvmVcpuEvents{
exception: exception{
sErrPending: 1,
- sErrHasEsr: 0,
- pad: [6]uint8{0, 0, 0, 0, 0, 0},
- sErrEsr: 1,
},
- rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
+
+ // vcpuSErrNMI is the event of system error to trigger sigbus.
+ vcpuSErrNMI = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 1,
+ sErrEsr: _ESR_ELx_SERR_NMI,
+ },
}
)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index b35c930e2..9433d4da5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int {
//
//go:nosplit
func bluepillStopGuest(c *vCPU) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
- uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillSigBus is reponsible for injecting sError to trigger sigbus.
+//
+//go:nosplit
+func bluepillSigBus(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
throw("sErr injection failed")
}
}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index eb05950cd..75085ac6a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall( // escapes: no.
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_NMI, 0); errno != 0 {
- throw("NMI injection failed")
- }
+ bluepillSigBus(c)
continue // Rerun vCPU.
default:
throw("run failed")
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index dd45ad10b..5979aef97 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr {
// NewAddressSpace returns a new pagetable root.
func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) {
// Allocate page tables and install system mappings.
- pageTables := pagetables.New(newAllocator())
- k.machine.mapUpperHalf(pageTables)
+ pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress)
// Return the new address space.
return &addressSpace{
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 84df0f878..b060d9544 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -38,6 +38,8 @@ const (
_KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
_KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
_KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+ _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a
+ _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00
)
// Arm64: Architectural Feature Access Control Register EL1.
@@ -149,6 +151,9 @@ const (
_ESR_SEGV_PEMERR_L1 = 0xd
_ESR_SEGV_PEMERR_L2 = 0xe
_ESR_SEGV_PEMERR_L3 = 0xf
+
+ // Custom ISS field definitions for system error.
+ _ESR_ELx_SERR_NMI = 0x1
)
// Arm64: MMIO base address used to dispatch hypercalls.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 61ed24d01..e2fffc99b 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,6 +41,9 @@ type machine struct {
// slots are currently being updated, and the caller should retry.
nextSlot uint32
+ // upperSharedPageTables tracks the read-only shared upper of all the pagetables.
+ upperSharedPageTables *pagetables.PageTables
+
// kernel is the set of global structures.
kernel ring0.Kernel
@@ -198,9 +202,7 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
m.vCPUsByTID = make(map[uint64]*vCPU)
m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
- m.kernel.Init(ring0.KernelOpts{
- PageTables: pagetables.New(newAllocator()),
- }, m.maxVCPUs)
+ m.kernel.Init(m.maxVCPUs)
// Pull the maximum slots.
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS)
@@ -212,6 +214,13 @@ func newMachine(vm int) (*machine, error) {
log.Debugf("The maximum number of slots is %d.", m.maxSlots)
m.usedSlots = make([]uintptr, m.maxSlots)
+ // Create the upper shared pagetables and kernel(sentry) pagetables.
+ m.upperSharedPageTables = pagetables.New(newAllocator())
+ m.mapUpperHalf(m.upperSharedPageTables)
+ m.upperSharedPageTables.Allocator.(*allocator).base.Drain()
+ m.upperSharedPageTables.MarkReadOnlyShared()
+ m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -225,7 +234,6 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
- m.mapUpperHalf(m.kernel.PageTables)
var physicalRegionsReadOnly []physicalRegion
var physicalRegionsAvailable []physicalRegion
@@ -625,3 +633,35 @@ func (c *vCPU) BounceToKernel() {
func (c *vCPU) BounceToHost() {
c.bounce(true)
}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c67127d95..8e03c310d 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error {
}
}
-// setSystemTimeLegacy calibrates and sets an approximate system time.
-func (c *vCPU) setSystemTimeLegacy() error {
- const minIterations = 10
- minimum := uint64(0)
- for iter := 0; ; iter++ {
- // Try to set the TSC to an estimate of where it will be
- // on the host during a "fast" system call iteration.
- start := uint64(ktime.Rdtsc())
- if err := c.setTSC(start + (minimum / 2)); err != nil {
- return err
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result.
- end := uint64(ktime.Rdtsc())
- if end < start {
- continue // Totally bogus: unstable TSC?
- }
- current := end - start
- if current < minimum || iter == 0 {
- minimum = current // Set our new minimum.
- }
- // Is this past minIterations and within ~10% of minimum?
- upperThreshold := (((minimum << 3) + minimum) >> 3)
- if iter >= minIterations && current <= upperThreshold {
- return nil
- }
- }
-}
-
// nonCanonical generates a canonical address return.
//
//go:nosplit
@@ -464,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) {
return physicalRegions
}
-var execRegions = func() (regions []region) {
+func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
+ // Map all the executible regions so that all the entry functions
+ // are mapped in the upper half.
applyVirtualRegions(func(vr virtualRegion) {
if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" {
return
}
+
if vr.accessType.Execute {
- regions = append(regions, vr.region)
+ r := vr.region
+ physical, length, ok := translateToPhysical(r.virtual)
+ if !ok || length < r.length {
+ panic("impossible translation")
+ }
+ pageTable.Map(
+ usermem.Addr(ring0.KernelStartAddress|r.virtual),
+ r.length,
+ pagetables.MapOpts{AccessType: usermem.Execute},
+ physical)
}
})
- return
-}()
-
-func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
- for _, r := range execRegions {
- physical, length, ok := translateToPhysical(r.virtual)
- if !ok || length < r.length {
- panic("impossilbe translation")
- }
- pageTable.Map(
- usermem.Addr(ring0.KernelStartAddress|r.virtual),
- r.length,
- pagetables.MapOpts{AccessType: usermem.Execute},
- physical)
- }
for start, end := range m.kernel.EntryRegions() {
regionLen := end - start
physical, length, ok := translateToPhysical(start)
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index a163f956d..fd92c3873 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error {
}
c.floatingPointState = arch.NewFloatingPointData()
+
+ return c.setSystemTime()
+}
+
+// setTSC sets the counter Virtual Offset.
+func (c *vCPU) setTSC(value uint64) error {
+ var (
+ reg kvmOneReg
+ data uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ reg.id = _KVM_ARM64_REGS_TIMER_CNT
+ data = uint64(value)
+
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
return nil
}
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ return c.setSystemTimeLegacy()
+}
+
//go:nosplit
func (c *vCPU) loadSegments(tid uint64) {
// TODO(gvisor.dev/issue/1238): TLS is not supported.
@@ -197,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
} else if !ring0.IsCanonical(regs.Sp) {
- return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info)
}
// Assign PCIDs.
@@ -233,10 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
case ring0.PageFault:
return c.fault(int32(syscall.SIGSEGV), info)
+ case ring0.El0ErrNMI:
+ return c.fault(int32(syscall.SIGBUS), info)
case ring0.Vector(bounce): // ring0.VirtualizationException
return usermem.NoAccess, platform.ErrContextInterrupt
- case ring0.El0Sync_undef,
- ring0.El1Sync_undef:
+ case ring0.El0SyncUndef:
+ return c.fault(int32(syscall.SIGILL), info)
+ case ring0.El1SyncUndef:
*info = arch.SignalInfo{
Signo: int32(syscall.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 87a573cc4..327d48465 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -58,46 +58,55 @@ type Vector uintptr
// Exception vectors.
const (
- El1SyncInvalid = iota
- El1IrqInvalid
- El1FiqInvalid
- El1ErrorInvalid
+ El1InvSync = iota
+ El1InvIrq
+ El1InvFiq
+ El1InvError
+
El1Sync
El1Irq
El1Fiq
- El1Error
+ El1Err
+
El0Sync
El0Irq
El0Fiq
- El0Error
- El0Sync_invalid
- El0Irq_invalid
- El0Fiq_invalid
- El0Error_invalid
- El1Sync_da
- El1Sync_ia
- El1Sync_sp_pc
- El1Sync_undef
- El1Sync_dbg
- El1Sync_inv
- El0Sync_svc
- El0Sync_da
- El0Sync_ia
- El0Sync_fpsimd_acc
- El0Sync_sve_acc
- El0Sync_sys
- El0Sync_sp_pc
- El0Sync_undef
- El0Sync_dbg
- El0Sync_inv
+ El0Err
+
+ El0InvSync
+ El0InvIrq
+ El0InvFiq
+ El0InvErr
+
+ El1SyncDa
+ El1SyncIa
+ El1SyncSpPc
+ El1SyncUndef
+ El1SyncDbg
+ El1SyncInv
+
+ El0SyncSVC
+ El0SyncDa
+ El0SyncIa
+ El0SyncFpsimdAcc
+ El0SyncSveAcc
+ El0SyncSys
+ El0SyncSpPc
+ El0SyncUndef
+ El0SyncDbg
+ El0SyncInv
+
+ El0ErrNMI
+ El0ErrBounce
+
_NR_INTERRUPTS
)
// System call vectors.
const (
- Syscall Vector = El0Sync_svc
- PageFault Vector = El0Sync_da
- VirtualizationException Vector = El0Error
+ Syscall Vector = El0SyncSVC
+ PageFault Vector = El0SyncDa
+ VirtualizationException Vector = El0ErrBounce
)
// VirtualAddressBits returns the number bits available for virtual addresses.
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index e6daf24df..f9765771e 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -23,6 +23,9 @@ import (
//
// This contains global state, shared by multiple CPUs.
type Kernel struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+
KernelArchState
}
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 00899273e..7a2275558 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -66,17 +66,9 @@ var (
KernelDataSegment SegmentDescriptor
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
-
- // cpuEntries is array of kernelEntry for all cpus
+ // cpuEntries is array of kernelEntry for all cpus.
cpuEntries []kernelEntry
// globalIDT is our set of interrupt gates.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index 508236e46..a014dcbc0 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -32,15 +32,8 @@ var (
KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
)
-// KernelOpts has initialization options for the kernel.
-type KernelOpts struct {
- // PageTables are the kernel pagetables; this must be provided.
- PageTables *pagetables.PageTables
-}
-
// KernelArchState contains architecture-specific state.
type KernelArchState struct {
- KernelOpts
}
// CPUArchState contains CPU-specific arch state.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 2370a9276..f489ad352 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -288,6 +288,10 @@
#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+/* ISS field definitions for system error */
+#define ESR_ELx_SERR_MASK (0x1)
+#define ESR_ELx_SERR_NMI (0x1)
+
// LOAD_KERNEL_ADDRESS loads a kernel address.
#define LOAD_KERNEL_ADDRESS(from, to) \
MOVD from, to; \
@@ -366,6 +370,19 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// EXCEPTION_WITH_ERROR is a common exception handler function.
+#define EXCEPTION_WITH_ERROR(user, vector) \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ WORD $0xd538601a; \ //MRS FAR_EL1, R26
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
+ MOVD $user, R3; \
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
+ MOVD $vector, R3; \
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
+ MRS ESR_EL1, R3; \
+ MOVD R3, CPU_ERROR_CODE(RSV_REG); \
+ B ·kernelExitToEl1(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -503,6 +520,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
MSR R1, ELR_EL1
+ // restore sentry's tls.
+ MOVD CPU_REGISTERS+PTRACE_TLS(RSV_REG), R1
+ MSR R1, TPIDR_EL0
+
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
@@ -659,21 +680,7 @@ el0_svc:
el0_da:
el0_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $1, R3
- MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- MRS ESR_EL1, R3
- MOVD R3, CPU_ERROR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, PageFault)
el0_fpsimd_acc:
B ·Shutdown(SB)
@@ -688,10 +695,7 @@ el0_sp_pc:
B ·Shutdown(SB)
el0_undef:
- MOVD $El0Sync_undef, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, El0SyncUndef)
el0_dbg:
B ·Shutdown(SB)
@@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0
TEXT ·El0_error(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ AND $ESR_ELx_SERR_MASK, R25, R24
+ CMP $ESR_ELx_SERR_NMI, R24
+ BEQ el0_nmi
+ B el0_bounce
+el0_nmi:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $El0ErrNMI, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ MRS ESR_EL1, R3
+ MOVD R3, CPU_ERROR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_bounce:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
WORD $0xd538601a //MRS FAR_EL1, R26
@@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0
MOVD $VirtualizationException, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·HaltAndResume(SB)
+ B ·kernelExitToEl1(SB)
TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 264be23d3..292f9d0cc 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -16,11 +16,9 @@ package ring0
// Init initializes a new kernel.
//
-// N.B. that constraints on KernelOpts must be satisfied.
-//
//go:nosplit
-func (k *Kernel) Init(opts KernelOpts, maxCPUs int) {
- k.init(opts, maxCPUs)
+func (k *Kernel) Init(maxCPUs int) {
+ k.init(maxCPUs)
}
// Halt halts execution.
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 3a9dff4cc..b55dc29b3 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -24,10 +24,7 @@ import (
)
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
-
+func (k *Kernel) init(maxCPUs int) {
entrySize := reflect.TypeOf(kernelEntry{}).Size()
var (
entries []kernelEntry
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index b294ccc7c..6cbbf001f 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -25,9 +25,7 @@ func HaltAndResume()
func HaltEl1SvcAndResume()
// init initializes architecture-specific state.
-func (k *Kernel) init(opts KernelOpts, maxCPUs int) {
- // Save the root page tables.
- k.PageTables = opts.PageTables
+func (k *Kernel) init(maxCPUs int) {
}
// init initializes architecture-specific state.
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 45eba960d..53bc3353c 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -47,43 +47,36 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
fmt.Fprintf(w, "\n// Vectors.\n")
- fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
- fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
- fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
- fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
- fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+ fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err)
fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
- fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+ fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err)
- fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
- fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
- fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
- fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+ fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa)
+ fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa)
+ fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc)
+ fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef)
+ fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg)
+ fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv)
- fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
- fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
- fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
- fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
- fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
- fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+ fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC)
+ fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa)
+ fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa)
+ fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc)
+ fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc)
+ fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys)
+ fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc)
+ fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef)
+ fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg)
+ fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv)
- fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
- fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
- fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
- fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
- fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
- fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
- fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
- fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
- fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
- fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+ fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI)
fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 7f18ac296..bc16a1622 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -30,6 +30,10 @@ type PageTables struct {
Allocator Allocator
// root is the pagetable root.
+ //
+ // For same archs such as amd64, the upper of the PTEs is cloned
+ // from and owned by upperSharedPageTables which are shared among
+ // many PageTables if upperSharedPageTables is not nil.
root *PTEs
// rootPhysical is the cached physical address of the root.
@@ -39,15 +43,52 @@ type PageTables struct {
// archPageTables includes architecture-specific features.
archPageTables
+
+ // upperSharedPageTables represents a read-only shared upper
+ // of the Pagetable. When it is not nil, the upper is not
+ // allowed to be modified.
+ upperSharedPageTables *PageTables
+
+ // upperStart is the start address of the upper portion that
+ // are shared from upperSharedPageTables
+ upperStart uintptr
+
+ // readOnlyShared indicates the Pagetables are read-only and
+ // own the ranges that are shared with other Pagetables.
+ readOnlyShared bool
}
-// New returns new PageTables.
-func New(a Allocator) *PageTables {
+// NewWithUpper returns new PageTables.
+//
+// upperSharedPageTables are used for mapping the upper of addresses,
+// starting at upperStart. These pageTables should not be touched (as
+// invalidations may be incorrect) after they are passed as an
+// upperSharedPageTables. Only when all dependent PageTables are gone
+// may they be used. The intenteded use case is for kernel page tables,
+// which are static and fixed.
+//
+// Precondition: upperStart must be between canonical ranges.
+// Precondition: upperStart must be pgdSize aligned.
+// precondition: upperSharedPageTables must be marked read-only shared.
+func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables {
p := new(PageTables)
p.Init(a)
+ if upperSharedPageTables != nil {
+ if !upperSharedPageTables.readOnlyShared {
+ panic("Only read-only shared pagetables can be used as upper")
+ }
+ p.upperSharedPageTables = upperSharedPageTables
+ p.upperStart = upperStart
+ p.cloneUpperShared()
+ }
return p
}
+// New returns new PageTables.
+func New(a Allocator) *PageTables {
+ return NewWithUpper(a, nil, 0)
+}
+
// mapVisitor is used for map.
type mapVisitor struct {
target uintptr // Input.
@@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // ignore change to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
if !opts.AccessType.Any() {
return p.Unmap(addr, length)
}
@@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// True is returned iff there was a previous mapping in the range.
//
-// Precondition: addr & length must be page-aligned.
+// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
// +checkescape:hard,stack
//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
+ if p.readOnlyShared {
+ panic("Should not modify read-only shared pagetables.")
+ }
+ if uintptr(addr)+length < uintptr(addr) {
+ panic("addr & length overflow")
+ }
+ if p.upperSharedPageTables != nil {
+ // ignore change to the read-only upper shared portion.
+ if uintptr(addr) >= p.upperStart {
+ return false
+ }
+ if uintptr(addr)+length > p.upperStart {
+ length = p.upperStart - uintptr(addr)
+ }
+ }
w := unmapWalker{
pageTables: p,
visitor: unmapVisitor{
@@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts)
w.iterateRange(uintptr(addr), uintptr(addr)+1)
return w.visitor.physical + offset, w.visitor.opts
}
+
+// MarkReadOnlyShared marks the pagetables read-only and can be shared.
+//
+// It is usually used on the pagetables that are used as the upper
+func (p *PageTables) MarkReadOnlyShared() {
+ p.readOnlyShared = true
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
index 520161755..a4e416af7 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -24,14 +24,6 @@ import (
// archPageTables is architecture-specific data.
type archPageTables struct {
- // root is the pagetable root for kernel space.
- root *PTEs
-
- // rootPhysical is the cached physical address of the root.
- //
- // This is saved only to prevent constant translation.
- rootPhysical uintptr
-
asid uint16
}
@@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
//
//go:nosplit
func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
- return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+ return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}
// Bits in page table entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index 0c153cf8c..e7ab887e5 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) {
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
}
+func pgdIndex(upperStart uintptr) uintptr {
+ if upperStart&(pgdSize-1) != 0 {
+ panic("upperStart should be pgd size aligned")
+ }
+ if upperStart >= upperBottom {
+ return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize
+ }
+ if upperStart < lowerTop {
+ return upperStart / pgdSize
+ }
+ panic("upperStart should be in canonical range")
+}
+
+// cloneUpperShared clone the upper from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ start := pgdIndex(p.upperStart)
+ copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage])
+}
+
// PTEs is a collection of entries.
type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 1a49f12a2..5392bf27a 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -36,7 +36,7 @@ const (
pudSize = 1 << pudShift
pgdSize = 1 << pgdShift
- ttbrASIDOffset = 55
+ ttbrASIDOffset = 48
ttbrASIDMask = 0xff
entriesPerPage = 512
@@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) {
p.Allocator = allocator
p.root = p.Allocator.NewPTEs()
p.rootPhysical = p.Allocator.PhysicalFor(p.root)
- p.archPageTables.root = p.Allocator.NewPTEs()
- p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// cloneUpperShared clone the upper from the upper shared page tables.
+//
+//go:nosplit
+func (p *PageTables) cloneUpperShared() {
+ if p.upperStart != upperBottom {
+ panic("upperStart should be the same as upperBottom")
+ }
+
+ // nothing to do for arm.
}
// PTEs is a collection of entries.
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
index c261d393a..157c9a7cc 100644
--- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr {
func (w *Walker) iterateRangeCanonical(start, end uintptr) {
pgdEntryIndex := w.pageTables.root
if start >= upperBottom {
- pgdEntryIndex = w.pageTables.archPageTables.root
+ pgdEntryIndex = w.pageTables.upperSharedPageTables.root
}
for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index d9621968c..37d02948f 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -24,6 +24,8 @@ import (
)
// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+//
+// +stateify savable
type SCMRightsVFS2 interface {
transport.RightsControlMessage
@@ -34,9 +36,11 @@ type SCMRightsVFS2 interface {
Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
}
-// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
-// maintained for each vfs.FileDescription and is release either when an FD is created or
-// when the Release method is called.
+// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference
+// is maintained for each vfs.FileDescription and is release either when an FD
+// is created or when the Release method is called.
+//
+// +stateify savable
type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 163af329b..9a2cac40b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// +stateify savable
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{})
func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &socketVFS2{
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index faa61160e..7e7857ac3 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return syserror.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+func (s *Stack) SetTCPSACKEnabled(bool) error {
return syserror.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return syserror.EACCES
}
@@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
if rawLine == "" {
- return fmt.Errorf("Failed to get raw line")
+ return fmt.Errorf("failed to get raw line")
}
parts := strings.SplitN(rawLine, ":", 2)
if len(parts) != 2 {
- return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ return fmt.Errorf("failed to get prefix from: %q", rawLine)
}
sliceStat = toSlice(stat)
fields := strings.Fields(strings.TrimSpace(parts[1]))
if len(fields) != len(sliceStat) {
- return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ return fmt.Errorf("failed to parse fields: %q", rawLine)
}
if _, ok := stat.(*inet.StatSNMPTCP); ok {
snmpTCP = true
@@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
}
if err != nil {
- return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err)
}
}
@@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 549787955..e0976fed0 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf
// marshalTarget and unmarshalTarget can be used.
type targetMaker interface {
// id uniquely identifies the target.
- id() stack.TargetID
+ id() targetID
- // marshal converts from a stack.Target to an ABI struct.
- marshal(target stack.Target) []byte
+ // marshal converts from a target to an ABI struct.
+ marshal(target target) []byte
- // unmarshal converts from the ABI matcher struct to a stack.Target.
- unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error)
+ // unmarshal converts from the ABI matcher struct to a target.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error)
}
-// targetMakers maps the TargetID of supported targets to the targetMaker that
+// A targetID uniquely identifies a target.
+type targetID struct {
+ // name is the target name as stored in the xt_entry_target struct.
+ name string
+
+ // networkProtocol is the protocol to which the target applies.
+ networkProtocol tcpip.NetworkProtocolNumber
+
+ // revision is the version of the target.
+ revision uint8
+}
+
+// target extends a stack.Target, allowing it to be used with the extension
+// system. The sentry only uses targets, never stack.Targets directly.
+type target interface {
+ stack.Target
+ id() targetID
+}
+
+// targetMakers maps the targetID of supported targets to the targetMaker that
// marshals and unmarshals it. It is immutable after package initialization.
-var targetMakers = map[stack.TargetID]targetMaker{}
+var targetMakers = map[targetID]targetMaker{}
func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) {
- tid := stack.TargetID{
- Name: name,
- NetworkProtocol: netProto,
- Revision: rev,
+ tid := targetID{
+ name: name,
+ networkProtocol: netProto,
+ revision: rev,
}
if _, ok := targetMakers[tid]; !ok {
return 0, false
@@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8
// Return the highest supported revision unless rev is higher.
for _, other := range targetMakers {
otherID := other.id()
- if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev {
- rev = uint8(otherID.Revision)
+ if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev {
+ rev = uint8(otherID.revision)
}
}
return rev, true
@@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) {
targetMakers[tm.id()] = tm
}
-func marshalTarget(target stack.Target) []byte {
- targetMaker, ok := targetMakers[target.ID()]
+func marshalTarget(tgt stack.Target) []byte {
+ // The sentry only uses targets, never stack.Targets directly.
+ target := tgt.(target)
+ targetMaker, ok := targetMakers[target.id()]
if !ok {
- panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID()))
+ panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id()))
}
return targetMaker.marshal(target)
}
-func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) {
- tid := stack.TargetID{
- Name: target.Name.String(),
- NetworkProtocol: filter.NetworkProtocol(),
- Revision: target.Revision,
+func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) {
+ tid := targetID{
+ name: target.Name.String(),
+ networkProtocol: filter.NetworkProtocol(),
+ revision: target.Revision,
}
targetMaker, ok := targetMakers[tid]
if !ok {
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index b560fae0d..70c561cce 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct.
- entries, info := getEntries4(table, tablename)
+ entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 4253f7bf4..5dbb604f0 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ id, ok := nameToID[tablename.String()]
if !ok {
return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
}
// Setup the info struct, which is the same in IPv4 and IPv6.
- entries, info := getEntries6(table, tablename)
+ entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename)
return entries, info, nil
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 904a12e38..b283d7229 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) {
}
}
+// Table names.
+const (
+ natTable = "nat"
+ mangleTable = "mangle"
+ filterTable = "filter"
+)
+
+// nameToID is immutable.
+var nameToID = map[string]stack.TableID{
+ natTable: stack.NATID,
+ mangleTable: stack.MangleID,
+ filterTable: stack.FilterID,
+}
+
+// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for
+// compatibility with netfilter extensions.
+func DefaultLinuxTables() *stack.IPTables {
+ tables := stack.DefaultTables()
+ tables.VisitTargets(func(oldTarget stack.Target) stack.Target {
+ switch val := oldTarget.(type) {
+ case *stack.AcceptTarget:
+ return &acceptTarget{AcceptTarget: *val}
+ case *stack.DropTarget:
+ return &dropTarget{DropTarget: *val}
+ case *stack.ErrorTarget:
+ return &errorTarget{ErrorTarget: *val}
+ case *stack.UserChainTarget:
+ return &userChainTarget{UserChainTarget: *val}
+ case *stack.ReturnTarget:
+ return &returnTarget{ReturnTarget: *val}
+ case *stack.RedirectTarget:
+ return &redirectTarget{RedirectTarget: *val}
+ default:
+ panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val))
+ }
+ })
+ return tables
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
@@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// TODO(gvisor.dev/issue/170): Support other tables.
var table stack.Table
switch replace.Name.String() {
- case stack.FilterTable:
+ case filterTable:
table = stack.EmptyFilterTable()
- case stack.NATTable:
+ case natTable:
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
@@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx], ipv6) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx)
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
-
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6))
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool {
return false
}
switch rule.Target.(type) {
- case *stack.AcceptTarget, *stack.DropTarget:
+ case *acceptTarget, *dropTarget:
return true
default:
return false
@@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
if !validUnderflow(rule, ipv6) {
return false
}
- _, ok := rule.Target.(*stack.AcceptTarget)
+ _, ok := rule.Target.(*acceptTarget)
return ok
}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 0e14447fe..f2653d523 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -26,6 +26,15 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// ErrorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const ErrorTargetName = "ERROR"
+
+// RedirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port and/or IP for packets.
+const RedirectTargetName = "REDIRECT"
+
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@@ -52,25 +61,96 @@ func init() {
})
}
+// The stack package provides some basic, useful targets for us. The following
+// types wrap them for compatibility with the extension system.
+
+type acceptTarget struct {
+ stack.AcceptTarget
+}
+
+func (at *acceptTarget) id() targetID {
+ return targetID{
+ networkProtocol: at.NetworkProtocol,
+ }
+}
+
+type dropTarget struct {
+ stack.DropTarget
+}
+
+func (dt *dropTarget) id() targetID {
+ return targetID{
+ networkProtocol: dt.NetworkProtocol,
+ }
+}
+
+type errorTarget struct {
+ stack.ErrorTarget
+}
+
+func (et *errorTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: et.NetworkProtocol,
+ }
+}
+
+type userChainTarget struct {
+ stack.UserChainTarget
+}
+
+func (uc *userChainTarget) id() targetID {
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: uc.NetworkProtocol,
+ }
+}
+
+type returnTarget struct {
+ stack.ReturnTarget
+}
+
+func (rt *returnTarget) id() targetID {
+ return targetID{
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
+type redirectTarget struct {
+ stack.RedirectTarget
+
+ // addr must be (un)marshalled when reading and writing the target to
+ // userspace, but does not affect behavior.
+ addr tcpip.Address
+}
+
+func (rt *redirectTarget) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rt.NetworkProtocol,
+ }
+}
+
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (sm *standardTargetMaker) id() stack.TargetID {
+func (sm *standardTargetMaker) id() targetID {
// Standard targets have the empty string as a name and no revisions.
- return stack.TargetID{
- NetworkProtocol: sm.NetworkProtocol,
+ return targetID{
+ networkProtocol: sm.NetworkProtocol,
}
}
-func (*standardTargetMaker) marshal(target stack.Target) []byte {
+
+func (*standardTargetMaker) marshal(target target) []byte {
// Translate verdicts the same way as the iptables tool.
var verdict int32
switch tg := target.(type) {
- case *stack.AcceptTarget:
+ case *acceptTarget:
verdict = -linux.NF_ACCEPT - 1
- case *stack.DropTarget:
+ case *dropTarget:
verdict = -linux.NF_DROP - 1
- case *stack.ReturnTarget:
+ case *returnTarget:
verdict = linux.NF_RETURN
case *JumpTarget:
verdict = int32(tg.Offset)
@@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTStandardTarget {
nflog("buf has wrong size for standard target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -114,20 +194,20 @@ type errorTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (em *errorTargetMaker) id() stack.TargetID {
+func (em *errorTargetMaker) id() targetID {
// Error targets have no revision.
- return stack.TargetID{
- Name: stack.ErrorTargetName,
- NetworkProtocol: em.NetworkProtocol,
+ return targetID{
+ name: ErrorTargetName,
+ networkProtocol: em.NetworkProtocol,
}
}
-func (*errorTargetMaker) marshal(target stack.Target) []byte {
+func (*errorTargetMaker) marshal(target target) []byte {
var errorName string
switch tg := target.(type) {
- case *stack.ErrorTarget:
- errorName = stack.ErrorTargetName
- case *stack.UserChainTarget:
+ case *errorTarget:
+ errorName = ErrorTargetName
+ case *userChainTarget:
errorName = tg.Name
default:
panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target))
@@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte {
},
}
copy(xt.Name[:], errorName)
- copy(xt.Target.Name[:], stack.ErrorTargetName)
+ copy(xt.Target.Name[:], ErrorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) != linux.SizeOfXTErrorTarget {
nflog("buf has insufficient size for error target %d", len(buf))
return nil, syserr.ErrInvalidArgument
}
- var errorTarget linux.XTErrorTarget
+ var errTgt linux.XTErrorTarget
buf = buf[:linux.SizeOfXTErrorTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &errTgt)
// Error targets are used in 2 cases:
- // * An actual error case. These rules have an error
- // named stack.ErrorTargetName. The last entry of the table
- // is usually an error case to catch any packets that
- // somehow fall through every rule.
+ // * An actual error case. These rules have an error named
+ // ErrorTargetName. The last entry of the table is usually an error
+ // case to catch any packets that somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch name := errorTarget.Name.String(); name {
- case stack.ErrorTargetName:
- return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil
+ switch name := errTgt.Name.String(); name {
+ case ErrorTargetName:
+ return &errorTarget{stack.ErrorTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}, nil
default:
// User defined chain.
- return &stack.UserChainTarget{
+ return &userChainTarget{stack.UserChainTarget{
Name: name,
NetworkProtocol: filter.NetworkProtocol(),
- }, nil
+ }}, nil
}
}
@@ -178,22 +259,22 @@ type redirectTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *redirectTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *redirectTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*redirectTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*redirectTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
// This is a redirect target named redirect
xt := linux.XTRedirectTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTRedirectTarget,
},
}
- copy(xt.Target.Name[:], stack.RedirectTargetName)
+ copy(xt.Target.Name[:], RedirectTargetName)
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
xt.NfRange.RangeSize = 1
@@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, xt)
}
-func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if len(buf) < linux.SizeOfXTRedirectTarget {
nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf))
return nil, syserr.ErrInvalidArgument
@@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- var redirectTarget linux.XTRedirectTarget
+ var rt linux.XTRedirectTarget
buf = buf[:linux.SizeOfXTRedirectTarget]
- binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+ binary.Unmarshal(buf, usermem.ByteOrder, &rt)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
- target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()}
+ target := redirectTarget{RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ }}
// RangeSize should be 1.
- nfRange := redirectTarget.NfRange
+ nfRange := rt.NfRange
if nfRange.RangeSize != 1 {
nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize)
return nil, syserr.ErrInvalidArgument
@@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
- target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
target.Port = ntohs(nfRange.RangeIPV4.MinPort)
return &target, nil
@@ -264,15 +347,15 @@ type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-func (rm *nfNATTargetMaker) id() stack.TargetID {
- return stack.TargetID{
- Name: stack.RedirectTargetName,
- NetworkProtocol: rm.NetworkProtocol,
+func (rm *nfNATTargetMaker) id() targetID {
+ return targetID{
+ name: RedirectTargetName,
+ networkProtocol: rm.NetworkProtocol,
}
}
-func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
- rt := target.(*stack.RedirectTarget)
+func (*nfNATTargetMaker) marshal(target target) []byte {
+ rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
TargetSize: nfNATMarhsalledSize,
@@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
},
}
- copy(nt.Target.Name[:], stack.RedirectTargetName)
- copy(nt.Range.MinAddr[:], rt.Addr)
- copy(nt.Range.MaxAddr[:], rt.Addr)
+ copy(nt.Target.Name[:], RedirectTargetName)
+ copy(nt.Range.MinAddr[:], rt.addr)
+ copy(nt.Range.MaxAddr[:], rt.addr)
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
@@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte {
return binary.Marshal(ret, usermem.ByteOrder, nt)
}
-func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) {
+func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if size := nfNATMarhsalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
@@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
return nil, syserr.ErrInvalidArgument
}
- target := stack.RedirectTarget{
- NetworkProtocol: filter.NetworkProtocol(),
- Addr: tcpip.Address(natRange.MinAddr[:]),
- Port: ntohs(natRange.MinProto),
+ target := redirectTarget{
+ RedirectTarget: stack.RedirectTarget{
+ NetworkProtocol: filter.NetworkProtocol(),
+ Port: ntohs(natRange.MinProto),
+ },
+ addr: tcpip.Address(natRange.MinAddr[:]),
}
return &target, nil
@@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
-func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) {
+func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return &stack.AcceptTarget{NetworkProtocol: netProto}, nil
+ return &acceptTarget{stack.AcceptTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_DROP - 1:
- return &stack.DropTarget{NetworkProtocol: netProto}, nil
+ return &dropTarget{stack.DropTarget{
+ NetworkProtocol: netProto,
+ }}, nil
case -linux.NF_QUEUE - 1:
nflog("unsupported iptables verdict QUEUE")
return nil, syserr.ErrInvalidArgument
case linux.NF_RETURN:
- return &stack.ReturnTarget{NetworkProtocol: netProto}, nil
+ return &returnTarget{stack.ReturnTarget{
+ NetworkProtocol: netProto,
+ }}, nil
default:
nflog("unknown iptables verdict %d", val)
return nil, syserr.ErrInvalidArgument
@@ -382,9 +473,9 @@ type JumpTarget struct {
}
// ID implements Target.ID.
-func (jt *JumpTarget) ID() stack.TargetID {
- return stack.TargetID{
- NetworkProtocol: jt.NetworkProtocol,
+func (jt *JumpTarget) id() targetID {
+ return targetID{
+ networkProtocol: jt.NetworkProtocol,
}
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 844acfede..352c51390 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.TCPProtocolNumber {
- return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber)
}
return &TCPMatcher{
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 63201201c..c88d8268d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
}
return &UDPMatcher{
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
index e8930f031..f061c5d62 100644
--- a/pkg/sentry/socket/netlink/provider_vfs2.go
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol
vfsfd := &s.vfsfd
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index c84d8bd7c..f4d034c13 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -36,9 +36,9 @@ type commandKind int
const (
kindNew commandKind = 0x0
- kindDel = 0x1
- kindGet = 0x2
- kindSet = 0x3
+ kindDel commandKind = 0x1
+ kindGet commandKind = 0x2
+ kindSet commandKind = 0x3
)
func typeKind(typ uint16) commandKind {
@@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
}
attrs = rest
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We add the local interface address here
+ // and ignore the IFA_ADDRESS.
switch ahdr.Type {
case linux.IFA_LOCAL:
err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
@@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
} else if err != nil {
return syserr.ErrInvalidArgument
}
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
}
}
return nil
}
+// delAddr handles RTM_DELADDR requests.
+func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We use the local interface address to
+ // remove the address and ignore the IFA_ADDRESS.
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err != nil {
+ return syserr.ErrBadLocalAddress
+ }
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
hdr := msg.Header()
@@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
return p.newAddr(ctx, msg, ms)
+ case linux.RTM_DELADDR:
+ return p.delAddr(ctx, msg, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index c83b23242..461d524e5 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -37,6 +37,8 @@ import (
// to/from the kernel.
//
// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 211f07947..86c634715 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1244,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
+ case linux.SO_ACCEPTCONN:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 4c6791fff..b0d9e4d9e 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -35,6 +35,8 @@ import (
// SocketVFS2 encapsulates all the state needed to represent a network stack
// endpoint in the kernel context.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
}
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &SocketVFS2{
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 1028d2a6e..fa9ac9059 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
-// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+// convertAddr converts an InterfaceAddr to a ProtocolAddress.
+func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) {
var (
- protocol tcpip.NetworkProtocolNumber
- address tcpip.Address
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ protocolAddress tcpip.ProtocolAddress
)
switch addr.Family {
case linux.AF_INET:
- if len(addr.Addr) < header.IPv4AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv4AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv4.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
-
+ address = tcpip.Address(addr.Addr)
case linux.AF_INET6:
- if len(addr.Addr) < header.IPv6AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv6AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv6AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv6.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
-
+ address = tcpip.Address(addr.Addr)
default:
- return syserror.ENOTSUP
+ return protocolAddress, syserror.ENOTSUP
}
- protocolAddress := tcpip.ProtocolAddress{
+ protocolAddress = tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: int(addr.PrefixLen),
},
}
+ return protocolAddress, nil
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
// Attach address to interface.
- if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network if it doesn't exist already.
+ localRoute := tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: nicID,
+ }
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ if rt.Equal(localRoute) {
+ return nil
+ }
+ }
+
+ // Local route does not exist yet. Add it.
+ s.Stack.AddRoute(localRoute)
+
+ return nil
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
+
+ // Remove addresses matching the address and prefix.
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
- // Add route for local network.
- s.Stack.AddRoute(tcpip.Route{
+ // Remove the corresponding local network route if it exists.
+ localRoute := tcpip.Route{
Destination: protocolAddress.AddressWithPrefix.Subnet(),
Gateway: "", // No gateway for local network.
- NIC: tcpip.NICID(idx),
+ NIC: nicID,
+ }
+ s.Stack.RemoveRoutes(func(rt tcpip.Route) bool {
+ return rt.Equal(localRoute)
})
+
return nil
}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cc7408698..cce0acc33 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "socket_refs.go",
package = "unix",
prefix = "socketOperations",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketOperations",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "socket_vfs2_refs.go",
package = "unix",
prefix = "socketVFS2",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketVFS2",
},
@@ -43,6 +43,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 26c3a51b9..3ebbd28b0 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "queue_refs.go",
package = "transport",
prefix = "queue",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "queue",
},
@@ -44,6 +44,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d6fc03520..b648273a4 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -32,6 +32,8 @@ import (
const initialLimit = 16 * 1024
// A RightsControlMessage is a control message containing FDs.
+//
+// +stateify savable
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
Clone() RightsControlMessage
@@ -336,7 +338,7 @@ type Receiver interface {
RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
- // called before droping all references to a Receiver.
+ // called before dropping all references to a Receiver.
Release(ctx context.Context)
}
@@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
c := q.control.Clone()
// Don't consume data since we are peeking.
- copied, data, _ = vecCopy(data, q.buffer)
+ copied, _, _ = vecCopy(data, q.buffer)
return copied, copied, c, false, q.addr, notify, nil
}
@@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
return copied, copied, c, cmTruncated, q.addr, notify, nil
}
+// Release implements Receiver.Release.
+func (q *streamQueueReceiver) Release(ctx context.Context) {
+ q.queueReceiver.Release(ctx)
+ q.control.Release(ctx)
+}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
// Passcred implements Endpoint.Passcred.
@@ -619,7 +627,7 @@ type ConnectedEndpoint interface {
SendMaxQueueSize() int64
// Release releases any resources owned by the ConnectedEndpoint. It should
- // be called before droping all references to a ConnectedEndpoint.
+ // be called before dropping all references to a ConnectedEndpoint.
Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
@@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.PasscredOption:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index a4a76d0a3..adad485a9 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
},
}
s.EnableLeakCheck()
-
return fs.NewFile(ctx, d, flags, &s)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 678355fb9..7a78444dc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// returns a corresponding file description.
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
@@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.EnableLeakCheck()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index 0ea4aab8b..563d60578 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -12,10 +12,12 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/time",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
"//pkg/syserror",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 245d2c5cf..167754537 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -19,10 +19,12 @@ import (
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/pkg/syserror"
@@ -57,7 +59,7 @@ type SaveOpts struct {
}
// Save saves the system state.
-func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
+func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error {
log.Infof("Sandbox save started, pausing all tasks.")
k.Pause()
k.ReceiveTaskStates()
@@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
err = ErrStateFile{err}
} else {
// Save the kernel.
- err = k.SaveTo(wc)
+ err = k.SaveTo(ctx, wc)
// ENOSPC is a state file error. This error can only come from
// writing the state file, and not from fs.FileOperations.Fsync
@@ -108,7 +110,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
+func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(r, n, clocks)
+ return k.LoadFrom(ctx, r, n, clocks, vfsOpts)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 9c9def7cd..bb1f715e2 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 849a47476..f7135ea46 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return 0, syserror.EINVAL
}
- r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
+ r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize)
r.SetFlags(linuxToFlags(flags).Settable())
defer r.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 47dadb800..e383a0a87 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -129,13 +129,27 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getPID(t, id, num)
return uintptr(v), nil, err
+ case linux.IPC_STAT:
+ arg := args[3].Pointer()
+ ds, err := ipcStat(t, id)
+ if err == nil {
+ _, err = ds.CopyOut(t, arg)
+ }
+
+ return 0, nil, err
+
+ case linux.GETZCNT:
+ v, err := getZCnt(t, id, num)
+ return uintptr(v), nil, err
+
+ case linux.GETNCNT:
+ v, err := getNCnt(t, id, num)
+ return uintptr(v), nil, err
+
case linux.IPC_INFO,
linux.SEM_INFO,
- linux.IPC_STAT,
linux.SEM_STAT,
- linux.SEM_STAT_ANY,
- linux.GETNCNT,
- linux.GETZCNT:
+ linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -171,6 +185,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP
return set.Change(t, creds, owner, perms)
}
+func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.GetStat(creds)
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
@@ -240,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
}
return int32(tg.ID()), nil
}
+
+func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountZeroWaiters(num, creds)
+}
+
+func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return 0, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.CountNegativeWaiters(num, creds)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 46616c961..1c4cdb0dd 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inCh chan struct{}
outCh chan struct{}
)
+
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
opts.Length -= n
@@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
inW, _ := waiter.NewChannelEntry(inCh)
inFile.EventRegister(&inW, EventMaskRead)
defer inFile.EventUnregister(&inW)
- continue // Need to refresh readiness.
+ // Need to refresh readiness.
+ continue
}
if err = t.Block(inCh); err != nil {
break
}
}
- if outFile.Readiness(EventMaskWrite) == 0 {
- if outCh == nil {
- outCh = make(chan struct{}, 1)
- outW, _ := waiter.NewChannelEntry(outCh)
- outFile.EventRegister(&outW, EventMaskWrite)
- defer outFile.EventUnregister(&outW)
- continue // Need to refresh readiness.
- }
- if err = t.Block(outCh); err != nil {
- break
- }
+ // Don't bother checking readiness of the outFile, because it's not a
+ // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds
+ // can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK.
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ // We might be ready to write now. Try again before
+ // blocking.
+ continue
+ }
+ if err = t.Block(outCh); err != nil {
+ break
}
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index 035e2a6b0..9ce4f280a 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
// waitForOut waits for dw.outfile to be read.
func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
- if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
- if dw.outCh == nil {
- dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
- dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
- // We might be ready now. Try again before blocking.
- return nil
- }
- if err := t.Block(dw.outCh); err != nil {
- return err
- }
- }
- return nil
+ // Don't bother checking readiness of the outFile, because it's not a
+ // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds
+ // can be "ready" but will reject writes of certain sizes with
+ // EWOULDBLOCK. See b/172075629, b/170743336.
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready to write now. Try again before blocking.
+ return nil
+ }
+ return t.Block(dw.outCh)
}
// destroy cleans up resources help by dw. No more calls to wait* can occur
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index c855608db..440c9307c 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -32,7 +32,7 @@ go_template_instance(
out = "file_description_refs.go",
package = "vfs",
prefix = "FileDescription",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FileDescription",
},
@@ -43,7 +43,7 @@ go_template_instance(
out = "mount_namespace_refs.go",
package = "vfs",
prefix = "MountNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "MountNamespace",
},
@@ -54,7 +54,7 @@ go_template_instance(
out = "filesystem_refs.go",
package = "vfs",
prefix = "Filesystem",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Filesystem",
},
@@ -87,6 +87,7 @@ go_library(
"pathname.go",
"permissions.go",
"resolving_path.go",
+ "save_restore.go",
"vfs.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -99,6 +100,7 @@ go_library(
"//pkg/gohacks",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 8f36c3e3b..a98aac52b 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -74,7 +74,7 @@ type epollInterestKey struct {
// +stateify savable
type epollInterest struct {
// epoll is the owning EpollInstance. epoll is immutable.
- epoll *EpollInstance
+ epoll *EpollInstance `state:"wait"`
// key is the file to which this epollInterest applies. key is immutable.
key epollInterestKey
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 183957ad8..546e445aa 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 2d27d9d35..ba6e6ed49 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
return vfs.PrependPathAtVFSRootError{}
}
- if &d.vfsd == mnt.Root() {
+ if mnt != nil && &d.vfsd == mnt.Root() {
return nil
}
if d.parent == nil {
@@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
d = d.parent
}
}
+
+// DebugPathname returns a pathname to d relative to its filesystem root.
+// DebugPathname does not correspond to any Linux function; it's used to
+// generate dentry pathnames for debugging.
+func DebugPathname(d *Dentry) string {
+ var b fspath.Builder
+ _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ return b.String()
+}
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
index 3f0b8f45b..107171b61 100644
--- a/pkg/sentry/vfs/inotify.go
+++ b/pkg/sentry/vfs/inotify.go
@@ -65,7 +65,7 @@ type Inotify struct {
// queue is used to notify interested parties when the inotify instance
// becomes readable or writable.
- queue waiter.Queue `state:"nosave"`
+ queue waiter.Queue
// evMu *only* protects the events list. We need a separate lock while
// queuing events: using mu may violate lock ordering, since at that point
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index 55783d4eb..1ff202f2a 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package lock provides POSIX and BSD style file locking for VFS2 file
-// implementations.
-//
-// The actual implementations can be found in the lock package under
-// sentry/fs/lock.
package vfs
import (
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 78f115bfa..3ea981ad4 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -106,6 +107,7 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount
if opts.ReadOnly {
mnt.setReadOnlyLocked(true)
}
+ refsvfs2.Register(mnt)
return mnt
}
@@ -470,11 +472,12 @@ func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
// tryIncMountedRef does not require that a reference is held on mnt.
func (mnt *Mount) tryIncMountedRef() bool {
for {
- refs := atomic.LoadInt64(&mnt.refs)
- if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
+ r := atomic.LoadInt64(&mnt.refs)
+ if r <= 0 { // r < 0 => MSB set => eagerly unmounted
return false
}
- if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
+ if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) {
+ refsvfs2.LogTryIncRef(mnt, r+1)
return true
}
}
@@ -484,29 +487,53 @@ func (mnt *Mount) tryIncMountedRef() bool {
func (mnt *Mount) IncRef() {
// In general, negative values for mnt.refs are valid because the MSB is
// the eager-unmount bit.
- atomic.AddInt64(&mnt.refs, 1)
+ r := atomic.AddInt64(&mnt.refs, 1)
+ refsvfs2.LogIncRef(mnt, r)
}
// DecRef decrements mnt's reference count.
func (mnt *Mount) DecRef(ctx context.Context) {
- refs := atomic.AddInt64(&mnt.refs, -1)
- if refs&^math.MinInt64 == 0 { // mask out MSB
- var vd VirtualDentry
- if mnt.parent() != nil {
- mnt.vfs.mountMu.Lock()
- mnt.vfs.mounts.seq.BeginWrite()
- vd = mnt.vfs.disconnectLocked(mnt)
- mnt.vfs.mounts.seq.EndWrite()
- mnt.vfs.mountMu.Unlock()
- }
- if mnt.root != nil {
- mnt.root.DecRef(ctx)
- }
- mnt.fs.DecRef(ctx)
- if vd.Ok() {
- vd.DecRef(ctx)
- }
+ r := atomic.AddInt64(&mnt.refs, -1)
+ if r&^math.MinInt64 == 0 { // mask out MSB
+ refsvfs2.Unregister(mnt)
+ mnt.destroy(ctx)
+ }
+}
+
+func (mnt *Mount) destroy(ctx context.Context) {
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
}
+ mnt.fs.DecRef(ctx)
+ if vd.Ok() {
+ vd.DecRef(ctx)
+ }
+}
+
+// RefType implements refsvfs2.CheckedObject.Type.
+func (mnt *Mount) RefType() string {
+ return "vfs.Mount"
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (mnt *Mount) LeakMessage() string {
+ return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs))
+}
+
+// LogRefs implements refsvfs2.CheckedObject.LogRefs.
+//
+// This should only be set to true for debugging purposes, as it can generate an
+// extremely large amount of output and drastically degrade performance.
+func (mnt *Mount) LogRefs() bool {
+ return false
}
// DecRef decrements mntns' reference count.
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index cb8c56bd3..cb882a983 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) {
parent := &Mount{}
point := &Dentry{}
if m := mt.Lookup(parent, point); m != nil {
- t.Errorf("empty mountTable lookup: got %p, wanted nil", m)
+ t.Errorf("Empty mountTable lookup: got %p, wanted nil", m)
}
}
@@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m == nil {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
k := keys[i&(numMounts-1)]
mi, ok := ms.Load(k)
if !ok {
- b.Fatalf("lookup failed")
+ b.Errorf("Lookup failed")
+ return
}
m := mi.(*Mount)
if parent := m.parent(); parent != k.mount {
- b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount)
+ return
}
if point := m.point(); point != k.dentry {
- b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry)
+ b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry)
+ return
}
}
}()
@@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m := mt.Lookup(k.mount, k.dentry)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
m := ms[k]
mu.RUnlock()
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
@@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
k := negkeys[i&(numMounts-1)]
m, _ := ms.Load(k)
if m != nil {
- b.Fatalf("lookup got %p, wanted nil", m)
+ b.Fatalf("Lookup got %p, wanted nil", m)
}
}
})
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index b7d122d22..cb48c37a1 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -98,7 +98,6 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- // FIXME(gvisor.dev/issue/1663): Slots need to be saved.
slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
@@ -212,6 +211,26 @@ loop:
}
}
+// Range calls f on each Mount in mt. If f returns false, Range stops iteration
+// and returns immediately.
+func (mt *mountTable) Range(f func(*Mount) bool) {
+ tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
+ slotPtr := mt.slots
+ last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes))
+ for {
+ slot := (*mountSlot)(slotPtr)
+ if slot.value != nil {
+ if !f((*Mount)(slot.value)) {
+ return
+ }
+ }
+ if slotPtr == last {
+ return
+ }
+ slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes)
+ }
+}
+
// Insert inserts the given mount into mt.
//
// Preconditions: mt must not already contain a Mount with the same mount point
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
new file mode 100644
index 000000000..7723ed643
--- /dev/null
+++ b/pkg/sentry/vfs/save_restore.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// FilesystemImplSaveRestoreExtension is an optional extension to
+// FilesystemImpl.
+type FilesystemImplSaveRestoreExtension interface {
+ // PrepareSave prepares this filesystem for serialization.
+ PrepareSave(ctx context.Context) error
+
+ // CompleteRestore completes restoration from checkpoint for this
+ // filesystem after deserialization.
+ CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error
+}
+
+// PrepareSave prepares all filesystems for serialization.
+func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.PrepareSave(ctx); err != nil {
+ ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to prepare for serialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestore completes restoration from checkpoint for all filesystems
+// after deserialization.
+func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.CompleteRestore(ctx, *opts); err != nil {
+ ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestoreOptions contains options to
+// VirtualFilesystem.CompleteRestore() and
+// FilesystemImplSaveRestoreExtension.CompleteRestore().
+type CompleteRestoreOptions struct {
+ // If ValidateFileSizes is true, filesystem implementations backed by
+ // remote filesystems should verify that file sizes have not changed
+ // between checkpoint and restore.
+ ValidateFileSizes bool
+
+ // If ValidateFileModificationTimestamps is true, filesystem
+ // implementations backed by remote filesystems should validate that file
+ // mtimes have not changed between checkpoint and restore.
+ ValidateFileModificationTimestamps bool
+}
+
+// saveMounts is called by stateify.
+func (vfs *VirtualFilesystem) saveMounts() []*Mount {
+ if atomic.LoadPointer(&vfs.mounts.slots) == nil {
+ // vfs.Init() was never called.
+ return nil
+ }
+ var mounts []*Mount
+ vfs.mounts.Range(func(mount *Mount) bool {
+ mounts = append(mounts, mount)
+ return true
+ })
+ return mounts
+}
+
+// loadMounts is called by stateify.
+func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
+ if mounts == nil {
+ return
+ }
+ vfs.mounts.Init()
+ for _, mount := range mounts {
+ vfs.mounts.Insert(mount)
+ }
+}
+
+func (mnt *Mount) afterLoad() {
+ if atomic.LoadInt64(&mnt.refs) != 0 {
+ refsvfs2.Register(mnt)
+ }
+}
+
+// afterLoad is called by stateify.
+func (epi *epollInterest) afterLoad() {
+ // Mark all epollInterests as ready after restore so that the next call to
+ // EpollInstance.ReadEvents() rechecks their readiness.
+ epi.Callback(nil)
+}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 38d2701d2..48d6252f7 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -71,7 +71,7 @@ type VirtualFilesystem struct {
// points.
//
// mounts is analogous to Linux's mount_hashtable.
- mounts mountTable
+ mounts mountTable `state:".([]*Mount)"`
// mountpoints maps mount points to mounts at those points in all
// namespaces. mountpoints is protected by mountMu.
@@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
// SyncAllFilesystems has the semantics of Linux's sync(2).
func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ var retErr error
+ for fs := range vfs.getFilesystems() {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef(ctx)
+ }
+ return retErr
+}
+
+func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} {
fss := make(map[*Filesystem]struct{})
vfs.filesystemsMu.Lock()
+ defer vfs.filesystemsMu.Unlock()
for fs := range vfs.filesystems {
if !fs.TryIncRef() {
continue
}
fss[fs] = struct{}{}
}
- vfs.filesystemsMu.Unlock()
- var retErr error
- for fs := range fss {
- if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
- retErr = err
- }
- fs.DecRef(ctx)
- }
- return retErr
+ return fss
}
// MkdirAllAt recursively creates non-existent directories on the given path
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
index f08599ebd..cb0001852 100644
--- a/pkg/shim/runsc/BUILD
+++ b/pkg/shim/runsc/BUILD
@@ -10,6 +10,7 @@ go_library(
],
visibility = ["//:sandbox"],
deps = [
+ "@com_github_containerd_containerd//log:go_default_library",
"@com_github_containerd_go_runc//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
index c5cf68efa..e7c9640ba 100644
--- a/pkg/shim/runsc/runsc.go
+++ b/pkg/shim/runsc/runsc.go
@@ -28,10 +28,12 @@ import (
"syscall"
"time"
+ "github.com/containerd/containerd/log"
runc "github.com/containerd/go-runc"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
+// Monitor is the default process monitor to be used by runsc.
var Monitor runc.ProcessMonitor = runc.Monitor
// DefaultCommand is the default command for Runsc.
@@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro
return &c, nil
}
+// CreateOpts is a set of options to Runsc.Create().
type CreateOpts struct {
runc.IO
ConsoleSocket runc.ConsoleSocket
@@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) {
return res.ExitStatus, nil
}
+// ExecOpts is a set of options to runsc.Exec().
type ExecOpts struct {
runc.IO
PidFile string
@@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts
return Monitor.Wait(cmd, ec)
}
+// DeleteOpts is a set of options to runsc.Delete().
type DeleteOpts struct {
Force bool
}
@@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
if err := json.NewDecoder(rd).Decode(&e); err != nil {
return nil, err
}
+ log.L.Debugf("Stats returned: %+v", e.Stats)
+ if e.Type != "stats" {
+ return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type)
+ }
+ if e.Stats == nil {
+ return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`)
+ }
return e.Stats, nil
}
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 089b3bbef..92c51879b 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "pending_list",
- out = "pending_list.go",
- package = "state",
- prefix = "pending",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectEncodeState",
- "ElementMapper": "pendingMapper",
- "Linker": "*pendingEntry",
- },
-)
-
-go_template_instance(
name = "deferred_list",
out = "deferred_list.go",
package = "state",
@@ -83,7 +70,6 @@ go_library(
"deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "pending_list.go",
"state.go",
"state_norace.go",
"state_race.go",
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 89467ca8e..e519ddeca 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -21,6 +21,7 @@ import (
"math"
"reflect"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/state/wire"
)
@@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
-// directly assignable. See unsafePointerTo.
+// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
@@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
- if v.IsValid() {
- obj.Set(unsafePointerTo(v))
- }
+ obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
@@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
- obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
@@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
ds.pending.PushBack(rootOds)
// Read the number of objects.
- lastID, object, err := ReadHeader(ds.r)
+ numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
@@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) {
var (
encoded wire.Object
ods *objectDecodeState
- id = objectID(1)
+ id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
- // Note that the structure of this decoding loop should match
- // the raw decoding loop in printer.go.
- for id <= objectID(lastID) {
- // Unmarshal the object.
+ // Note that the structure of this decoding loop should match the raw
+ // decoding loop in state/pretty/pretty.printer.printStream().
+ for i := uint64(0); i < numObjects; {
+ // Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
-
- // Is this a type object? Handle inline.
- if wt, ok := encoded.(*wire.Type); ok {
- ds.types.Register(wt)
+ switch we := encoded.(type) {
+ case *wire.Type:
+ ds.types.Register(we)
tid++
encoded = nil
continue
+ case wire.Uint:
+ id = objectID(we)
+ i++
+ // Unmarshal and resolve the actual object.
+ encoded = wire.Load(ds.r)
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+ // For error handling.
+ ods = nil
+ encoded = nil
+ default:
+ Failf("wanted type or object ID, got %#v", encoded)
}
-
- // Actually resolve the object.
- ods = ds.lookup(id)
- if ods != nil {
- // Decode the object.
- ds.decodeObject(ods, ods.obj, encoded)
- } else {
- // If an object hasn't had interest registered
- // previously or isn't yet valid, we deferred
- // decoding until interest is registered.
- ds.deferred[id] = encoded
- }
-
- // For error handling.
- ods = nil
- encoded = nil
- id++
}
}); err != nil {
// Include as much information as we can, taking into account
@@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) {
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
- Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
+ numDeferred := 0
for id, encoded := range ds.deferred {
- // Shoud never happen, the graph was bogus.
- Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ numDeferred++
+ if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
+ typ := ds.types.LookupType(typeID(s.TypeID))
+ log.Warningf("unused deferred object: ID %d, type %v", id, typ)
+ } else {
+ log.Warningf("unused deferred object: ID %d, %#v", id, encoded)
+ }
+ }
+ if numDeferred != 0 {
+ Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
index d048f61a1..f1208e2a2 100644
--- a/pkg/state/decode_unsafe.go
+++ b/pkg/state/decode_unsafe.go
@@ -15,13 +15,62 @@
package state
import (
+ "fmt"
"reflect"
+ "runtime"
"unsafe"
)
-// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
-// values representing unexported fields. This bypasses visibility, but not
-// type safety.
-func unsafePointerTo(obj reflect.Value) reflect.Value {
+// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned
+// reflect.Value is usable in assignments even if obj was obtained by the use
+// of unexported struct fields.
+//
+// Preconditions: obj.CanAddr().
+func reflectValueRWAddr(obj reflect.Value) reflect.Value {
return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
}
+
+// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the
+// returned reflect.Value is usable in assignments even if obj was obtained by
+// the use of unexported struct fields.
+//
+// Preconditions:
+// * arr.Kind() == reflect.Array.
+// * i, j, k >= 0.
+// * i <= j <= k <= arr.Len().
+func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value {
+ if arr.Kind() != reflect.Array {
+ panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array))
+ }
+ if i < 0 || j < 0 || k < 0 {
+ panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k))
+ }
+ if i > j {
+ panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j))
+ }
+ if j > k {
+ panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k))
+ }
+ if k > arr.Len() {
+ panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len()))
+ }
+
+ sliceTyp := reflect.SliceOf(arr.Type().Elem())
+ if i == arr.Len() {
+ // By precondition, i == j == k == arr.Len().
+ return reflect.MakeSlice(sliceTyp, 0, 0)
+ }
+ slh := reflect.SliceHeader{
+ // reflect.Value.CanAddr() == false for arrays, so we need to get the
+ // address from the first element of the array.
+ Data: arr.Index(i).UnsafeAddr(),
+ Len: j - i,
+ Cap: k - i,
+ }
+ slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem()
+ // Before slobj is constructed, arr holds the only pointer-typed pointer to
+ // the array since reflect.SliceHeader.Data is a uintptr, so arr must be
+ // kept alive.
+ runtime.KeepAlive(arr)
+ return slobj
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 92fcad4e9..560e7c2a3 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -17,13 +17,14 @@ package state
import (
"context"
"reflect"
+ "sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectEncodeState the type and identity of an object occupying a memory
// address range. This is the value type for addrSet, and the intrusive entry
-// for the pending and deferred lists.
+// for the deferred list.
type objectEncodeState struct {
// id is the assigned ID for this object.
id objectID
@@ -47,7 +48,6 @@ type objectEncodeState struct {
// references may be updated directly and automatically.
refs []*wire.Ref
- pendingEntry
deferredEntry
}
@@ -93,9 +93,15 @@ type encodeState struct {
// serialized.
pendingTypes []wire.Type
- // pending is the list of objects to be serialized. Serialization does
+ // pending maps object IDs to objects to be serialized. Serialization does
// not actually occur until the full object graph is computed.
- pending pendingList
+ pending map[objectID]*objectEncodeState
+
+ // encodedStructs maps reflect.Values representing structs to previous
+ // encodings of those structs. This is necessary to avoid duplicate calls
+ // to SaverLoader.StateSave() that may result in multiple calls to
+ // Sink.SaveValue() for a given field, resulting in object duplication.
+ encodedStructs map[reflect.Value]*wire.Struct
// stats tracks time data.
stats Stats
@@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
// depending on this value knows there's nothing there.
return
}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg, gap := es.values.Find(addr)
+ if seg.Ok() {
// Ensure the map types match.
existing := seg.Value()
if existing.obj.Type() != obj.Type() {
@@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
}
// Record the map.
+ r := addrRange{addr, addr + 1}
oes := &objectEncodeState{
id: es.nextID(),
obj: obj,
how: encodeMapAsValue,
}
- es.values.Add(addrRange{addr, addr + 1}, oes)
- es.pending.PushBack(oes)
+ // Use Insert instead of InsertWithoutMergingUnchecked when race
+ // detection is enabled to get additional sanity-checking from Merge.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
// See above: no ref recording.
@@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
obj: obj,
}
es.zeroValues[typ] = oes
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
}
@@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
size = 1 // See above.
}
- // Calculate the container.
end := addr + size
r := addrRange{addr, end}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg := es.values.LowerBoundSegment(addr)
+ var (
+ oes *objectEncodeState
+ gap addrGapIterator
+ )
+
+ // Does at least one previously-registered object overlap this one?
+ if seg.Ok() && seg.Start() < end {
existing := seg.Value()
- switch {
- case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
- // The object is a perfect match. Happy path. Avoid the
- // traversal and just return directly. We don't need to
- // encode the type information or any dots here.
+
+ if seg.Range() == r && typ == existing.obj.Type() {
+ // This exact object is already registered. Avoid the traversal and
+ // just return directly. We don't need to encode the type
+ // information or any dots here.
ref.Root = wire.Uint(existing.id)
existing.refs = append(existing.refs, ref)
return
+ }
- case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
- // The previously registered object is larger than
- // this, no need to update. But we expect some
- // traversal below.
+ if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) {
+ // This object is contained within a previously-registered object.
+ // Perform traversal from the container to the new object.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr)
+ ref.Type = es.findType(existing.obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
- case seg.Start() == addr && seg.End() == end:
- if !isSameSizeParent(obj, existing.obj.Type()) {
- break // Needs traversal.
+ // This object contains one or more previously-registered objects.
+ // Remove them and update existing references to use the new one.
+ oes := &objectEncodeState{
+ // Reuse the root ID of the first contained element.
+ id: existing.id,
+ obj: obj,
+ }
+ type elementEncodeState struct {
+ addr uintptr
+ typ reflect.Type
+ refs []*wire.Ref
+ }
+ var (
+ elems []elementEncodeState
+ gap addrGapIterator
+ )
+ for {
+ // Each contained object should be completely contained within
+ // this one.
+ if raceEnabled && !r.IsSupersetOf(seg.Range()) {
+ Failf("containing object %#v does not contain existing object %#v", obj, existing.obj)
}
- fallthrough // Needs update.
-
- case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
- // Update the object and redo the encoding.
- old := existing.obj
- existing.obj = obj
+ elems = append(elems, elementEncodeState{
+ addr: seg.Start(),
+ typ: existing.obj.Type(),
+ refs: existing.refs,
+ })
+ delete(es.pending, existing.id)
es.deferred.Remove(existing)
- es.deferred.PushBack(existing)
-
- // The previously registered object is superseded by
- // this new object. We are guaranteed to not have any
- // mergeable neighbours in this segment set.
- if !raceEnabled {
- seg.SetRangeUnchecked(r)
- } else {
- // Add extra paranoid. This will be statically
- // removed at compile time unless a race build.
- es.values.Remove(seg)
- es.values.Add(r, existing)
- seg = es.values.LowerBoundSegment(addr)
+ gap = es.values.Remove(seg)
+ seg = gap.NextSegment()
+ if !seg.Ok() || seg.Start() >= end {
+ break
}
-
- // Compute the traversal required & update references.
- dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
- wt := es.findType(obj.Type())
- for _, ref := range existing.refs {
+ existing = seg.Value()
+ }
+ wt := es.findType(typ)
+ for _, elem := range elems {
+ dots := traverse(typ, elem.typ, addr, elem.addr)
+ for _, ref := range elem.refs {
+ ref.Root = wire.Uint(oes.id)
ref.Dots = append(ref.Dots, dots...)
ref.Type = wt
}
- default:
- // There is a non-sensical overlap.
- Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ oes.refs = append(oes.refs, elem.refs...)
}
-
- // Compute the new reference, record and return it.
- ref.Root = wire.Uint(existing.id)
- ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
- ref.Type = es.findType(obj.Type())
- existing.refs = append(existing.refs, ref)
+ // Finally register the new containing object.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
return
}
- // The only remaining case is a pointer value that doesn't overlap with
- // any registered addresses. Create a new entry for it, and start
- // tracking the first reference we just created.
- oes := &objectEncodeState{
+ // No existing object overlaps this one. Register a new object.
+ oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
+ if seg.Ok() {
+ gap = seg.PrevGap()
+ } else {
+ gap = es.values.LastGap()
+ }
if !raceEnabled {
- es.values.AddWithoutMerging(r, oes)
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
- // Merges should never happen. This is just enabled extra
- // sanity checks because the Merge function below will panic.
- es.values.Add(r, oes)
+ es.values.Insert(gap, r, oes)
}
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
@@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) {
// encodeStruct encodes a composite object.
func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ if s, ok := es.encodedStructs[obj]; ok {
+ *dest = s
+ return
+ }
+ s := &wire.Struct{}
+ *dest = s
+ es.encodedStructs[obj] = s
+
// Ensure that the obj is addressable. There are two cases when it is
// not. First, is when this is dispatched via SaveValue. Second, when
// this is a map key as a struct. Either way, we need to make a copy to
@@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
obj = localObj.Elem()
}
- // Prepare the value.
- s := &wire.Struct{}
- *dest = s
-
// Look the type up in the database.
te, ok := es.types.Lookup(obj.Type())
if te == nil {
@@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) {
Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
}
- // Check that items are pending.
- if es.pending.Front() == nil {
+ // Check that we have objects to serialize.
+ if len(es.pending) == 0 {
Failf("pending is empty?")
}
- // Write the header with the number of objects. Note that there is no
- // way that es.lastID could conflict with objectID, which would
- // indicate that an impossibly large encoding.
- if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ // Write the header with the number of objects.
+ if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
// Serialize all pending types and pending objects. Note that we don't
// bother removing from this list as we walk it because that just
// wastes time. It will not change after this point.
- var id objectID
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
}
- for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
- id++ // First object is 1.
- if oes.id != id {
- Failf("expected id %d, got %d", id, oes.id)
- }
-
- // Marshall the object.
+ // Emit objects in ID order.
+ ids := make([]objectID, 0, len(es.pending))
+ for id := range es.pending {
+ ids = append(ids, id)
+ }
+ sort.Slice(ids, func(i, j int) bool {
+ return ids[i] < ids[j]
+ })
+ for _, id := range ids {
+ // Encode the id.
+ wire.Save(es.w, wire.Uint(id))
+ // Marshal the object.
+ oes := es.pending[id]
wire.Save(es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
Failf("error serializing object %#v: %w", oes.encoded, err)
}
-
- // Check what we wrote.
- if id != es.lastID {
- Failf("expected %d objects, wrote %d", es.lastID, id)
- }
}
// objectFlag indicates that the length is a # of objects, rather than a raw
@@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error {
})
}
-// pendingMapper is for the pending list.
-type pendingMapper struct{}
-
-func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
-
// deferredMapper is for the deferred list.
type deferredMapper struct{}
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
index 887f453a9..c6e8bb31d 100644
--- a/pkg/state/pretty/pretty.go
+++ b/pkg/state/pretty/pretty.go
@@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
buf.WriteString(typ)
buf.WriteString(")(")
buf.WriteString(baseRef)
+ buf.WriteString(")")
for _, component := range x.Dots {
switch v := component.(type) {
case *wire.FieldName:
@@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
}
}
- buf.WriteString(")")
fullRef = buf.String()
}
if p.html {
@@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
// Note that this loop must match the general structure of the
// loop in decode.go. But we don't register type information,
// etc. and just print the raw structures.
+ type objectAndID struct {
+ id uint64
+ obj wire.Object
+ }
var (
tid uint64 = 1
- objects []wire.Object
+ objects []objectAndID
)
- for oid := uint64(1); oid <= length; {
- // Unmarshal the object.
+ for i := uint64(0); i < length; {
+ // Unmarshal either a type object or object ID.
encoded := wire.Load(r)
-
- // Is this a type?
- if typ, ok := encoded.(*wire.Type); ok {
+ switch we := encoded.(type) {
+ case *wire.Type:
str, _ := p.format(graph, 0, encoded)
tag := fmt.Sprintf("g%dt%d", graph, tid)
- p.typeSpecs[tag] = typ
+ p.typeSpecs[tag] = we
if p.html {
// See below.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
@@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
return err
}
tid++
- continue
+ case wire.Uint:
+ // Unmarshal the actual object.
+ objects = append(objects, objectAndID{
+ id: uint64(we),
+ obj: wire.Load(r),
+ })
+ i++
+ default:
+ return fmt.Errorf("wanted type or object ID, got %#v", encoded)
}
-
- // Otherwise, it is a node.
- objects = append(objects, encoded)
- oid++
}
- for i, encoded := range objects {
- // oid starts at 1.
- oid := i + 1
+ for _, objAndID := range objects {
// Format the node.
- str, _ := p.format(graph, 0, encoded)
- tag := fmt.Sprintf("g%dr%d", graph, oid)
+ str, _ := p.format(graph, 0, objAndID.obj)
+ tag := fmt.Sprintf("g%dr%d", graph, objAndID.id)
if p.html {
// Create a little tag with an anchor next to it for linking.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
diff --git a/pkg/state/state.go b/pkg/state/state.go
index acb629969..6b8540f03 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error {
func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
es := encodeState{
- ctx: ctx,
- w: w,
- types: makeTypeEncodeDatabase(),
- zeroValues: make(map[reflect.Type]*objectEncodeState),
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ pending: make(map[objectID]*objectEncodeState),
+ encodedStructs: make(map[reflect.Value]*wire.Struct),
}
// Perform the encoding.
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
index bd2c2b399..69143d194 100644
--- a/pkg/state/tests/struct.go
+++ b/pkg/state/tests/struct.go
@@ -54,12 +54,47 @@ type outerArray struct {
}
// +stateify savable
+type outerSlice struct {
+ inner []inner
+}
+
+// +stateify savable
type inner struct {
v int64
}
// +stateify savable
+type outerFieldValue struct {
+ inner innerFieldValue
+}
+
+// +stateify savable
+type innerFieldValue struct {
+ v int64 `state:".(*savedFieldValue)"`
+}
+
+// +stateify savable
+type savedFieldValue struct {
+ v int64
+}
+
+func (ifv *innerFieldValue) saveV() *savedFieldValue {
+ return &savedFieldValue{ifv.v}
+}
+
+func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) {
+ ifv.v = sfv.v
+}
+
+// +stateify savable
type system struct {
v1 interface{}
v2 interface{}
}
+
+// +stateify savable
+type system3 struct {
+ v1 interface{}
+ v2 interface{}
+ v3 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index de9d17aa7..c91c2c032 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -15,6 +15,7 @@
package tests
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/state"
@@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) {
}
func TestEmbeddedPointers(t *testing.T) {
- var (
- ofs outerSame
- of1 outerFieldFirst
- of2 outerFieldSecond
- oa outerArray
- )
+ // Give each int64 a random value to prevent Go from using
+ // runtime.staticuint64s, which confounds tests for struct duplication.
+ magic := func() int64 {
+ for {
+ n := rand.Int63()
+ if n < 0 || n > 255 {
+ return n
+ }
+ }
+ }
+
+ ofs := outerSame{inner{magic()}}
+ of1 := outerFieldFirst{inner{magic()}, magic()}
+ of2 := outerFieldSecond{magic(), inner{magic()}}
+ oa := outerArray{[2]inner{{magic()}, {magic()}}}
+ osl := outerSlice{oa.inner[:]}
+ ofv := outerFieldValue{innerFieldValue{magic()}}
runTestCases(t, false, "embedded-pointers", []interface{}{
system{&ofs, &ofs.inner},
@@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) {
system{&oa, &oa.inner[1]},
system{&oa.inner[0], &oa},
system{&oa.inner[1], &oa},
+ system3{&oa, &oa.inner[0], &oa.inner[1]},
+ system3{&oa, &oa.inner[1], &oa.inner[0]},
+ system3{&oa.inner[0], &oa, &oa.inner[1]},
+ system3{&oa.inner[1], &oa, &oa.inner[0]},
+ system3{&oa.inner[0], &oa.inner[1], &oa},
+ system3{&oa.inner[1], &oa.inner[0], &oa},
+ system{&oa, &osl},
+ system{&osl, &oa},
+ system{&ofv, &ofv.inner},
+ system{&ofv.inner, &ofv},
})
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 12b061def..b196324c7 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -97,6 +97,9 @@ type testConnection struct {
func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, err
+ }
entry, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&entry, waiter.EventOut)
@@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
// Give c.Read() a chance to block before closing the connection.
@@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
c.SetDeadline(time.Now().Add(time.Minute))
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 6f81b0164..530f2ae2f 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -205,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker {
if !ok {
t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
}
- options := ip.Options()
+ options := []byte(ip.Options())
// cmp.Diff does not consider nil slices equal to empty slices, but we do.
if len(want) == 0 && len(options) == 0 {
return
@@ -859,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker {
}
}
+// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
+func ICMPv4Pointer(want uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Pointer(); got != want {
+ t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
+ }
+ }
+}
+
// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
// This assumes that the payload exactly makes up the rest of the slice.
func ICMPv4Checksum() TransportChecker {
@@ -953,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
}
}
+// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
+// field.
+func ICMPv6TypeSpecific(want uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ if got := icmpv6.TypeSpecific(); got != want {
+ t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
+func ICMPv6Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ payload := icmpv6.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 504408878..2f13dea6a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -99,7 +99,8 @@ const (
// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4ReassemblyTimeout ICMPv4Code = 1
)
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
@@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
+// Pointer returns the pointer field in a Parameter Problem packet.
+func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] }
+
+// SetPointer sets the pointer field in a Parameter Problem packet.
+func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
+
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 4c6e4be64..961b77628 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,7 +39,6 @@ import (
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Options | Padding |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-//
const (
versIHL = 0
tos = 1
@@ -93,7 +93,7 @@ type IPv4Fields struct {
DstAddr tcpip.Address
}
-// IPv4 represents an ipv4 header stored in a byte array.
+// IPv4 is an IPv4 header.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other
@@ -106,10 +106,13 @@ const (
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
- // that there are only 4 bits to represents the header length in 32-bit
- // units, the header cannot exceed 15*4 = 60 bytes.
+ // that there are only 4 bits (max 0xF (15)) to represent the header length
+ // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumOptionsSize is the largest size the IPv4 options can be.
+ IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize
+
// IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
//
// Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
@@ -130,7 +133,7 @@ const (
// IPv4ProtocolNumber is IPv4's network protocol number.
IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
- // IPv4Version is the version of the ipv4 protocol.
+ // IPv4Version is the version of the IPv4 protocol.
IPv4Version = 4
// IPv4AllSystems is the all systems IPv4 multicast address as per
@@ -148,6 +151,13 @@ const (
// packet that every IPv4 capable host must be able to
// process/reassemble.
IPv4MinimumProcessableDatagramSize = 576
+
+ // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791,
+ // section 3.2:
+ // Every internet module must be able to forward a datagram of 68 octets
+ // without further fragmentation. This is because an internet header may be
+ // up to 60 octets, and the minimum fragment is 8 octets.
+ IPv4MinimumMTU = 68
)
// Flags that may be set in an IPv4 packet.
@@ -191,14 +201,13 @@ func IPVersion(b []byte) int {
// Internet Header Length is the length of the internet header in 32
// bit words, and thus points to the beginning of the data. Note that
// the minimum value for a correct header is 5.
-//
const (
ipVersionShift = 4
ipIHLMask = 0x0f
IPv4IHLStride = 4
)
-// HeaderLength returns the value of the "header length" field of the ipv4
+// HeaderLength returns the value of the "header length" field of the IPv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & ipIHLMask) * IPv4IHLStride
@@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) {
b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
-// ID returns the value of the identifier field of the ipv4 header.
+// ID returns the value of the identifier field of the IPv4 header.
func (b IPv4) ID() uint16 {
return binary.BigEndian.Uint16(b[id:])
}
-// Protocol returns the value of the protocol field of the ipv4 header.
+// Protocol returns the value of the protocol field of the IPv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
-// Flags returns the "flags" field of the ipv4 header.
+// Flags returns the "flags" field of the IPv4 header.
func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
@@ -232,41 +241,44 @@ func (b IPv4) More() bool {
return b.Flags()&IPv4FlagMoreFragments != 0
}
-// TTL returns the "TTL" field of the ipv4 header.
+// TTL returns the "TTL" field of the IPv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
}
-// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+// FragmentOffset returns the "fragment offset" field of the IPv4 header.
func (b IPv4) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}
-// TotalLength returns the "total length" field of the ipv4 header.
+// TotalLength returns the "total length" field of the IPv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
-// Checksum returns the checksum field of the ipv4 header.
+// Checksum returns the checksum field of the IPv4 header.
func (b IPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[checksum:])
}
-// SourceAddress returns the "source address" field of the ipv4 header.
+// SourceAddress returns the "source address" field of the IPv4 header.
func (b IPv4) SourceAddress() tcpip.Address {
return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
}
-// DestinationAddress returns the "destination address" field of the ipv4
+// DestinationAddress returns the "destination address" field of the IPv4
// header.
func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// Options returns a a buffer holding the options.
-func (b IPv4) Options() []byte {
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
+// Options returns a buffer holding the options.
+func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
- return b[options:hdrLen:hdrLen]
+ return IPv4Options(b[options:hdrLen:hdrLen])
}
// TransportProtocol implements Network.TransportProtocol.
@@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte {
return b[b.HeaderLength():][:b.PayloadLength()]
}
-// PayloadLength returns the length of the payload portion of the ipv4 packet.
+// PayloadLength returns the length of the payload portion of the IPv4 packet.
func (b IPv4) PayloadLength() uint16 {
return b.TotalLength() - uint16(b.HeaderLength())
}
-// TOS returns the "type of service" field of the ipv4 header.
+// TOS returns the "type of service" field of the IPv4 header.
func (b IPv4) TOS() (uint8, uint32) {
return b[tos], 0
}
-// SetTOS sets the "type of service" field of the ipv4 header.
+// SetTOS sets the "type of service" field of the IPv4 header.
func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
@@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) {
b[ttl] = v
}
-// SetTotalLength sets the "total length" field of the ipv4 header.
+// SetTotalLength sets the "total length" field of the IPv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
-// SetChecksum sets the checksum field of the ipv4 header.
+// SetChecksum sets the checksum field of the IPv4 header.
func (b IPv4) SetChecksum(v uint16) {
binary.BigEndian.PutUint16(b[checksum:], v)
}
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
-// ipv4 header.
+// IPv4 header.
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
@@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) {
binary.BigEndian.PutUint16(b[id:], v)
}
-// SetSourceAddress sets the "source address" field of the ipv4 header.
+// SetSourceAddress sets the "source address" field of the IPv4 header.
func (b IPv4) SetSourceAddress(addr tcpip.Address) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
}
-// SetDestinationAddress sets the "destination address" field of the ipv4
+// SetDestinationAddress sets the "destination address" field of the IPv4
// header.
func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
}
-// CalculateChecksum calculates the checksum of the ipv4 header.
+// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return Checksum(b[:b.HeaderLength()], 0)
}
-// Encode encodes all the fields of the ipv4 header.
+// Encode encodes all the fields of the IPv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
@@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
}
-// EncodePartial updates the total length and checksum fields of ipv4 header,
+// EncodePartial updates the total length and checksum fields of IPv4 header,
// taking in the partial checksum, which is the checksum of the header without
// the total length and checksum fields. It is useful in cases when similar
// packets are produced.
@@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool {
}
return addr[0] == 0x7f
}
+
+// ========================= Options ==========================
+
+// An IPv4OptionType can hold the valuse for the Type in an IPv4 option.
+type IPv4OptionType byte
+
+// These constants are needed to identify individual options in the option list.
+// While RFC 791 (page 31) says "Every internet module must be able to act on
+// every option." This has not generally been adhered to and some options have
+// very low rates of support. We do not support options other than those shown
+// below.
+
+const (
+ // IPv4OptionListEndType is the option type for the End Of Option List
+ // option. Anything following is ignored.
+ IPv4OptionListEndType IPv4OptionType = 0
+
+ // IPv4OptionNOPType is the No-Operation option. May appear between other
+ // options and may appear multiple times.
+ IPv4OptionNOPType IPv4OptionType = 1
+
+ // IPv4OptionRecordRouteType is used by each router on the path of the packet
+ // to record its path. It is carried over to an Echo Reply.
+ IPv4OptionRecordRouteType IPv4OptionType = 7
+
+ // IPv4OptionTimestampType is the option type for the Timestamp option.
+ IPv4OptionTimestampType IPv4OptionType = 68
+
+ // ipv4OptionTypeOffset is the offset in an option of its type field.
+ ipv4OptionTypeOffset = 0
+
+ // IPv4OptionLengthOffset is the offset in an option of its length field.
+ IPv4OptionLengthOffset = 1
+)
+
+// Potential errors when parsing generic IP options.
+var (
+ ErrIPv4OptZeroLength = errors.New("zero length IP option")
+ ErrIPv4OptDuplicate = errors.New("duplicate IP option")
+ ErrIPv4OptInvalid = errors.New("invalid IP option")
+ ErrIPv4OptMalformed = errors.New("malformed IP option")
+ ErrIPv4OptionTruncated = errors.New("truncated IP option")
+ ErrIPv4OptionAddress = errors.New("bad IP option address")
+)
+
+// IPv4Option is an interface representing various option types.
+type IPv4Option interface {
+ // Type returns the type identifier of the option.
+ Type() IPv4OptionType
+
+ // Size returns the size of the option in bytes.
+ Size() uint8
+
+ // Contents returns a slice holding the contents of the option.
+ Contents() []byte
+}
+
+var _ IPv4Option = (*IPv4OptionGeneric)(nil)
+
+// IPv4OptionGeneric is an IPv4 Option of unknown type.
+type IPv4OptionGeneric []byte
+
+// Type implements IPv4Option.
+func (o *IPv4OptionGeneric) Type() IPv4OptionType {
+ return IPv4OptionType((*o)[ipv4OptionTypeOffset])
+}
+
+// Size implements IPv4Option.
+func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) }
+
+// Contents implements IPv4Option.
+func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) }
+
+// IPv4OptionIterator is an iterator pointing to a specific IP option
+// at any point of time. It also holds information as to a new options buffer
+// that we are building up to hand back to the caller.
+type IPv4OptionIterator struct {
+ options IPv4Options
+ // ErrCursor is where we are while parsing options. It is exported as any
+ // resulting ICMP packet is supposed to have a pointer to the byte within
+ // the IP packet where the error was detected.
+ ErrCursor uint8
+ nextErrCursor uint8
+ newOptions [IPv4MaximumOptionsSize]byte
+ writePoint int
+}
+
+// MakeIterator sets up and returns an iterator of options. It also sets up the
+// building of a new option set.
+func (o IPv4Options) MakeIterator() IPv4OptionIterator {
+ return IPv4OptionIterator{
+ options: o,
+ nextErrCursor: IPv4MinimumSize,
+ }
+}
+
+// RemainingBuffer returns the remaining (unused) part of the new option buffer,
+// into which a new option may be written.
+func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options {
+ return IPv4Options(i.newOptions[i.writePoint:])
+}
+
+// ConsumeBuffer marks a portion of the new buffer as used.
+func (i *IPv4OptionIterator) ConsumeBuffer(size int) {
+ i.writePoint += size
+}
+
+// PushNOPOrEnd puts one of the single byte options onto the new options.
+// Only values 0 or 1 (ListEnd or NOP) are valid input.
+func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) {
+ if val > IPv4OptionNOPType {
+ panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val))
+ }
+ i.newOptions[i.writePoint] = byte(val)
+ i.writePoint++
+}
+
+// Finalize returns the completed replacement options buffer padded
+// as needed.
+func (i *IPv4OptionIterator) Finalize() IPv4Options {
+ // RFC 791 page 31 says:
+ // The options might not end on a 32-bit boundary. The internet header
+ // must be filled out with octets of zeros. The first of these would
+ // be interpreted as the end-of-options option, and the remainder as
+ // internet header padding.
+ // Since the buffer is already zero filled we just need to step the write
+ // pointer up to the next multiple of 4.
+ options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3])
+ // Poison the write pointer.
+ i.writePoint = len(i.newOptions)
+ return options
+}
+
+// Next returns the next IP option in the buffer/list of IP options.
+// It returns
+// - A slice of bytes holding the next option or nil if there is error.
+// - A boolean which is true if parsing of all the options is complete.
+// - An error which is non-nil if an error condition was encountered.
+func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
+ // The opts slice gets shorter as we process the options. When we have no
+ // bytes left we are done.
+ if len(i.options) == 0 {
+ return nil, true, nil
+ }
+
+ i.ErrCursor = i.nextErrCursor
+
+ optType := IPv4OptionType(i.options[ipv4OptionTypeOffset])
+
+ if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType {
+ optionBody := i.options[:1]
+ i.options = i.options[1:]
+ i.nextErrCursor = i.ErrCursor + 1
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+ }
+
+ // There are no more single byte options defined. All the rest have a length
+ // field so we need to sanity check it.
+ if len(i.options) == 1 {
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ optLen := i.options[IPv4OptionLengthOffset]
+
+ if optLen == 0 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptZeroLength
+ }
+
+ if optLen == 1 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ if optLen > uint8(len(i.options)) {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptionTruncated
+ }
+
+ optionBody := i.options[:optLen]
+ i.nextErrCursor = i.ErrCursor + optLen
+ i.options = i.options[optLen:]
+
+ // Check the length of some option types that we know.
+ switch optType {
+ case IPv4OptionTimestampType:
+ if optLen < IPv4OptionTimestampHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionTimestamp(optionBody)
+ return &retval, false, nil
+
+ case IPv4OptionRecordRouteType:
+ if optLen < IPv4OptionRecordRouteHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionRecordRoute(optionBody)
+ return &retval, false, nil
+ }
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+}
+
+//
+// IP Timestamp option - RFC 791 page 22.
+// +--------+--------+--------+--------+
+// |01000100| length | pointer|oflw|flg|
+// +--------+--------+--------+--------+
+// | internet address |
+// +--------+--------+--------+--------+
+// | timestamp |
+// +--------+--------+--------+--------+
+// | ... |
+//
+// Type = 68
+//
+// The Option Length is the number of octets in the option counting
+// the type, length, pointer, and overflow/flag octets (maximum
+// length 40).
+//
+// The Pointer is the number of octets from the beginning of this
+// option to the end of timestamps plus one (i.e., it points to the
+// octet beginning the space for next timestamp). The smallest
+// legal value is 5. The timestamp area is full when the pointer
+// is greater than the length.
+//
+// The Overflow (oflw) [4 bits] is the number of IP modules that
+// cannot register timestamps due to lack of space.
+//
+// The Flag (flg) [4 bits] values are
+//
+// 0 -- time stamps only, stored in consecutive 32-bit words,
+//
+// 1 -- each timestamp is preceded with internet address of the
+// registering entity,
+//
+// 3 -- the internet address fields are prespecified. An IP
+// module only registers its timestamp if it matches its own
+// address with the next specified internet address.
+//
+// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC.
+//
+// The Timestamp is a right-justified, 32-bit timestamp in
+// milliseconds since midnight UT. If the time is not available in
+// milliseconds or cannot be provided with respect to midnight UT
+// then any time may be inserted as a timestamp provided the high
+// order bit of the timestamp field is set to one to indicate the
+// use of a non-standard value.
+
+// IPv4OptTSFlags sefines the values expected in the Timestamp
+// option Flags field.
+type IPv4OptTSFlags uint8
+
+//
+// Timestamp option specific related constants.
+const (
+ // IPv4OptionTimestampHdrLength is the length of the timestamp option header.
+ IPv4OptionTimestampHdrLength = 4
+
+ // IPv4OptionTimestampSize is the size of an IP timestamp.
+ IPv4OptionTimestampSize = 4
+
+ // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address.
+ IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize
+
+ // IPv4OptionTimestampMaxSize is limited by space for options
+ IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize
+
+ // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp
+ // is present.
+ IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0
+
+ // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and
+ // IP are present.
+ IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1
+
+ // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that
+ // predefined IP is present.
+ IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3
+)
+
+// ipv4TimestampTime provides the current time as specified in RFC 791.
+func ipv4TimestampTime(clock tcpip.Clock) uint32 {
+ const millisecondsPerDay = 24 * 3600 * 1000
+ const nanoPerMilli = 1000000
+ return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+}
+
+// IP Timestamp option fields.
+const (
+ // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field.
+ IPv4OptTSPointerOffset = 2
+
+ // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow
+ // fields, (each being 4 bits).
+ IPv4OptTSOFLWAndFLGOffset = 3
+ // These constants define the sub byte fields of the Flag and OverFlow field.
+ ipv4OptionTimestampOverflowshift = 4
+ ipv4OptionTimestampFlagsMask byte = 0x0f
+)
+
+var _ IPv4Option = (*IPv4OptionTimestamp)(nil)
+
+// IPv4OptionTimestamp is a Timestamp option from RFC 791.
+type IPv4OptionTimestamp []byte
+
+// Type implements IPv4Option.Type().
+func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType }
+
+// Size implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
+
+// Contents implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+
+// Pointer returns the pointer field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Pointer() uint8 {
+ return (*ts)[IPv4OptTSPointerOffset]
+}
+
+// Flags returns the flags field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags {
+ return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask)
+}
+
+// Overflow returns the Overflow field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Overflow() uint8 {
+ return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift
+}
+
+// IncOverflow increments the Overflow field in the IP Timestamp option. It
+// returns the incremented value. If the return value is 0 then the field
+// overflowed.
+func (ts *IPv4OptionTimestamp) IncOverflow() uint8 {
+ (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift
+ return ts.Overflow()
+}
+
+// UpdateTimestamp updates the fields of the next free timestamp slot.
+func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) {
+ slot := (*ts)[ts.Pointer()-1:]
+
+ switch ts.Flags() {
+ case IPv4OptionTimestampOnlyFlag:
+ binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize
+ case IPv4OptionTimestampWithIPFlag:
+ if n := copy(slot, addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ case IPv4OptionTimestampWithPredefinedIPFlag:
+ if tcpip.Address(slot[:IPv4AddressSize]) == addr {
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ }
+ }
+}
+
+// RecordRoute option specific related constants.
+//
+// from RFC 791 page 20:
+// Record Route
+//
+// +--------+--------+--------+---------//--------+
+// |00000111| length | pointer| route data |
+// +--------+--------+--------+---------//--------+
+// Type=7
+//
+// The record route option provides a means to record the route of
+// an internet datagram.
+//
+// The option begins with the option type code. The second octet
+// is the option length which includes the option type code and the
+// length octet, the pointer octet, and length-3 octets of route
+// data. The third octet is the pointer into the route data
+// indicating the octet which begins the next area to store a route
+// address. The pointer is relative to this option, and the
+// smallest legal value for the pointer is 4.
+const (
+ // IPv4OptionRecordRouteHdrLength is the length of the Record Route option
+ // header.
+ IPv4OptionRecordRouteHdrLength = 3
+
+ // IPv4OptRRPointerOffset is the offset to the pointer field in an RR
+ // option, which points to the next free slot in the list of addresses.
+ IPv4OptRRPointerOffset = 2
+)
+
+var _ IPv4Option = (*IPv4OptionRecordRoute)(nil)
+
+// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791.
+type IPv4OptionRecordRoute []byte
+
+// Pointer returns the pointer field in the IP RecordRoute option.
+func (rr *IPv4OptionRecordRoute) Pointer() uint8 {
+ return (*rr)[IPv4OptRRPointerOffset]
+}
+
+// StoreAddress stores the given IPv4 address into the next free slot.
+func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) {
+ start := rr.Pointer() - 1 // A one based number.
+ // start and room checked by caller.
+ if n := copy((*rr)[start:], addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize
+}
+
+// Type implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType }
+
+// Size implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
+
+// Contents implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index c5d8a3456..4e7e5f76a 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -101,8 +101,10 @@ const (
// The address is ff02::2.
IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
- // section 5.
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
+ // section 5:
+ // IPv6 requires that every link in the Internet have an MTU of 1280 octets
+ // or greater. This is known as the IPv6 minimum link MTU.
IPv6MinimumMTU = 1280
// IPv6Loopback is the IPv6 Loopback address.
@@ -373,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
+// address.
+func IsV6LoopbackAddress(addr tcpip.Address) bool {
+ return addr == IPv6Loopback
+}
+
// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
// link-local multicast address.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index dc239a0d0..2777f1411 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) {
const count = 1000000
var wg sync.WaitGroup
+ defer wg.Wait()
wg.Add(1)
go func() {
defer wg.Done()
@@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) {
}
}()
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
}
-
- rx.Flush()
}
- }()
- wg.Wait()
+ rx.Flush()
+ }
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 560477926..b3e8c4b92 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
//
// We don't clone the original packet buffer so that the new packet buffer
// does not have any of its headers set.
- pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
+ //
+ // We trim the link headers from the cloned buffer as the sniffer doesn't
+ // handle link headers.
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ vv.TrimFront(len(pkt.LinkHeader().View()))
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv})
switch protocol {
case header.IPv4ProtocolNumber:
if ok := parse.IPv4(pkt); !ok {
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 0243424f6..86f14db76 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "tun_endpoint_refs.go",
package = "tun",
prefix = "tunEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tunEndpoint",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index f94491026..cda6328a2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// 2. Creating a new NIC.
id := tcpip.NICID(s.UniqueID())
- // TODO(gvisor.dev/1486): enable leak check for tunEndpoint.
endpoint := &tunEndpoint{
Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
stack: s,
@@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
name: name,
isTap: prefix == "tap",
}
+ endpoint.EnableLeakCheck()
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index b40dde96b..8a6bcfc2c 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -30,5 +30,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7df77c66e..33a4a0720 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -18,6 +18,7 @@
package arp
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -121,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
@@ -144,34 +145,43 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
remoteAddr := tcpip.Address(h.ProtocolAddressSender())
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- // As per RFC 826, under Packet Reception:
- // Swap hardware and protocol fields, putting the local hardware and
- // protocol addresses in the sender fields.
- //
- // Send the packet to the (new) target hardware address on the same
- // hardware on which the request was received.
- origSender := h.HardwareAddressSender()
- r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
+ respPkt.NetworkProtocolNumber = ProtocolNumber
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), origSender)
- copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress())
+ if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ origSender := h.HardwareAddressSender()
+ if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize))
+ }
+ if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -199,6 +209,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
+ stack *stack.Stack
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -227,26 +238,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- r := &stack.Route{
- NetProto: ProtocolNumber,
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ if len(remoteLinkAddr) == 0 {
+ remoteLinkAddr = header.EthernetBroadcastAddress
}
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+
+ nicID := nic.ID()
+ if len(localAddr) == 0 {
+ addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
+ }
+
+ if len(addr.Address) == 0 {
+ return tcpip.ErrNetworkUnreachable
+ }
+
+ localAddr = addr.Address
+ } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return tcpip.ErrBadLocalAddress
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ pkt.NetworkProtocolNumber = ProtocolNumber
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), linkEP.LinkAddress())
- copy(h.ProtocolAddressSender(), localAddr)
- copy(h.ProtocolAddressTarget(), addr)
-
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
@@ -286,6 +315,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
// address must be added to every NIC that should respond to ARP requests. See
// ProtocolAddress for more details.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 626af975a..087ee9c66 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -78,13 +79,11 @@ func (t eventType) String() string {
type eventInfo struct {
eventType eventType
nicID tcpip.NICID
- addr tcpip.Address
- linkAddr tcpip.LinkAddress
- state stack.NeighborState
+ entry stack.NeighborEntry
}
func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
}
// arpDispatcher implements NUDDispatcher to validate the dispatching of
@@ -96,35 +95,29 @@ type arpDispatcher struct {
var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
-func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryChanged,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryRemoved,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
@@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address,
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" {
return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
}
case <-ctx.Done():
@@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
wantEvent := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: test.senderAddr,
- linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- state: stack.Stale,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ State: stack.Stale,
+ },
}
if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
t.Fatal(err)
@@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
}
- if got, want := neigh.LocalAddr, stackAddr; got != want {
- t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
- }
if got, want := neigh.State, stack.Stale; got != want {
t.Errorf("got neighbor State = %s, want = %s", got, want)
}
@@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.LinkEndpoint
+
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
+ testAddr := tcpip.Address([]byte{1, 2, 3, 4})
+
tests := []struct {
name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Unicast with no local address available",
remoteLinkAddr: remoteLinkAddr,
- expectLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
{
- name: "Multicast",
+ name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectLinkAddr: header.EthernetBroadcastAddress,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
for _, test := range tests {
- p := arp.NewProtocol(nil)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(arp.ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
- }
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ }
+ }
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
- }
+ // We pass a test network interface to LinkAddressRequest with the same
+ // NIC ID and link endpoint used by the NIC we created earlier so that we
+ // can mock a link address request and observe the packets sent to the
+ // link endpoint even though the stack uses the real NIC to validate the
+ // local address.
+ if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ }
+
+ rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index ed502a473..936601287 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
+//
+// releaseCB is a callback that will run when the fragment reassembly of a
+// packet is complete or cancelled. releaseCB take a a boolean argument which is
+// true iff the reassembly is cancelled due to timeout. releaseCB should be
+// passed only with the first fragment of a packet. If more than one releaseCB
+// are passed for the same packet, only the first releaseCB will be saved for
+// the packet and the succeeding ones will be dropped by running them
+// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -171,6 +179,12 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
+ if releaseCB != nil {
+ if !r.setCallback(releaseCB) {
+ // We got a duplicate callback. Release it immediately.
+ releaseCB(false /* timedOut */)
+ }
+ }
f.mu.Unlock()
res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
@@ -178,14 +192,14 @@ func (f *Fragmentation) Process(
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
- f.release(r)
+ f.release(r, false /* timedOut */)
f.mu.Unlock()
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
if done {
- f.release(r)
+ f.release(r, false /* timedOut */)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
@@ -195,14 +209,14 @@ func (f *Fragmentation) Process(
if tail == nil {
break
}
- f.release(tail)
+ f.release(tail, false /* timedOut */)
}
}
f.mu.Unlock()
return res, firstFragmentProto, done, nil
}
-func (f *Fragmentation) release(r *reassembler) {
+func (f *Fragmentation) release(r *reassembler, timedOut bool) {
// Before releasing a fragment we need to check if r is already marked as done.
// Otherwise, we would delete it twice.
if r.checkDoneOrMark() {
@@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) {
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
f.size = 0
}
+
+ r.release(timedOut) // releaseCB may run.
}
// releaseReassemblersLocked releases already-expired reassemblers, then
@@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() {
break
}
// If the oldest reassembler has already expired, release it.
- f.release(r)
+ f.release(r, true /* timedOut*/)
}
}
// PacketFragmenter is the book-keeping struct for packet fragmentation.
type PacketFragmenter struct {
- transportHeader buffer.View
- data buffer.VectorisedView
- reserve int
- innerMTU int
- fragmentCount int
- currentFragment int
- fragmentOffset int
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ fragmentPayloadLen int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
}
// MakePacketFragmenter prepares the struct needed for packet fragmentation.
//
// pkt is the packet to be fragmented.
//
-// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
// have.
//
// reserve is the number of bytes that should be reserved for the headers in
// each generated fragment.
-func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
// As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
// repeated in each fragment. However we do not currently support any header
// of that kind yet, so the following computation is valid for both IPv4 and
@@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa
var fragmentableData buffer.VectorisedView
fragmentableData.AppendView(pkt.TransportHeader().View())
fragmentableData.Append(pkt.Data)
- fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+ fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
return PacketFragmenter{
- data: fragmentableData,
- reserve: reserve,
- innerMTU: innerMTU,
- fragmentCount: fragmentCount,
+ data: fragmentableData,
+ reserve: reserve,
+ fragmentPayloadLen: int(fragmentPayloadLen),
+ fragmentCount: int(fragmentCount),
}
}
@@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int,
})
// Copy data for the fragment.
- copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
offset := pf.fragmentOffset
pf.fragmentOffset += copied
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index d3c7d7f92..5dcd10730 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) {
f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
if err != nil {
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
@@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) {
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) {
func TestMemoryLimits(t *testing.T) {
f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) {
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
got := f.size
want := 1
@@ -377,7 +377,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) {
tests := []struct {
name string
- innerMTU int
+ fragmentPayloadLen uint32
transportHeaderLen int
payloadSize int
wantFragments []fragmentInfo
}{
{
name: "Packet exactly fits in MTU",
- innerMTU: 1280,
+ fragmentPayloadLen: 1280,
transportHeaderLen: 0,
payloadSize: 1280,
wantFragments: []fragmentInfo{
@@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet exactly does not fit in MTU",
- innerMTU: 1000,
+ fragmentPayloadLen: 1000,
transportHeaderLen: 0,
payloadSize: 1001,
wantFragments: []fragmentInfo{
@@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a transport header",
- innerMTU: 560,
+ fragmentPayloadLen: 560,
transportHeaderLen: 40,
payloadSize: 560,
wantFragments: []fragmentInfo{
@@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a huge transport header",
- innerMTU: 500,
+ fragmentPayloadLen: 500,
transportHeaderLen: 1300,
payloadSize: 500,
wantFragments: []fragmentInfo{
@@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) {
originalPayload.AppendView(pkt.TransportHeader().View())
originalPayload.Append(pkt.Data)
var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
for i := 0; ; i++ {
fragPkt, offset, copied, more := pf.BuildNextFragment()
wantFragment := test.wantFragments[i]
@@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) {
if more != wantFragment.more {
t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
}
- if got := fragPkt.Size(); got > test.innerMTU {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
}
if got := fragPkt.AvailableHeaderBytes(); got != reserve {
t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
@@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) {
})
}
}
+
+func TestReleaseCallback(t *testing.T) {
+ const (
+ proto = 99
+ )
+
+ var result int
+ var callbackReasonIsTimeout bool
+ cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+
+ tests := []struct {
+ name string
+ callbacks []func(bool)
+ timeout bool
+ wantResult int
+ wantCallbackReasonIsTimeout bool
+ }{
+ {
+ name: "callback runs on release",
+ callbacks: []func(bool){cb1},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "first callback is nil",
+ callbacks: []func(bool){nil, cb2},
+ timeout: false,
+ wantResult: 2,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "two callbacks - first one is set",
+ callbacks: []func(bool){cb1, cb2},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "callback runs on timeout",
+ callbacks: []func(bool){cb1},
+ timeout: true,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: true,
+ },
+ {
+ name: "no callbacks",
+ callbacks: []func(bool){nil},
+ timeout: false,
+ wantResult: 0,
+ wantCallbackReasonIsTimeout: false,
+ },
+ }
+
+ id := FragmentID{ID: 0}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ result = 0
+ callbackReasonIsTimeout = false
+
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+
+ for i, cb := range test.callbacks {
+ _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
+ if err != nil {
+ t.Errorf("f.Process error = %s", err)
+ }
+ }
+
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatalf("Reassemberr not found")
+ }
+ f.release(r, test.timeout)
+
+ if result != test.wantResult {
+ t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ }
+ if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
+ t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9bb051a30..c0cc0bde0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -41,6 +41,7 @@ type reassembler struct {
heap fragHeap
done bool
creationTime int64
+ callback func(bool)
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
@@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
+
+func (r *reassembler) setCallback(c func(bool)) bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.callback != nil {
+ return false
+ }
+ r.callback = c
+ return true
+}
+
+func (r *reassembler) release(timedOut bool) {
+ r.mu.Lock()
+ callback := r.callback
+ r.callback = nil
+ r.mu.Unlock()
+
+ if callback != nil {
+ callback(timedOut)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..fa2a70dc8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
+
+func TestSetCallback(t *testing.T) {
+ result := 0
+ reasonTimeout := false
+
+ cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
+
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ if !r.setCallback(cb1) {
+ t.Errorf("setCallback failed")
+ }
+ if r.setCallback(cb2) {
+ t.Errorf("setCallback should fail if one is already set")
+ }
+ r.release(true)
+ if result != 1 {
+ t.Errorf("got result = %d, want = 1", result)
+ }
+ if !reasonTimeout {
+ t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index f20b94d97..8873bd91f 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
- t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition {
+ netHdr := pkt.Network()
+ t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress())
t.dataCalls++
return stack.TransportPacketHandled
}
@@ -304,6 +305,10 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -604,7 +609,8 @@ func TestIPv4Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -690,6 +696,10 @@ func TestIPv4ReceiveControl(t *testing.T) {
view[i] = uint8(i)
}
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
+
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
nic.testObject.protocol = 10
@@ -699,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) {
nic.testObject.typ = c.expectedTyp
nic.testObject.extra = c.expectedExtra
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -780,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls)
}
@@ -792,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -892,7 +906,8 @@ func TestIPv6Receive(t *testing.T) {
if _, _, ok := proto.Parse(pkt); !ok {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if nic.testObject.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls)
}
@@ -1009,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
+ pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
if want := c.expectedCount; nic.testObject.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want)
}
@@ -1063,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum tcpip.NetworkProtocolNumber
nicAddr tcpip.Address
remoteAddr tcpip.Address
- pktGen func(*testing.T, tcpip.Address) buffer.View
+ pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView
checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
expectedErr *tcpip.Error
}{
@@ -1073,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1087,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1115,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv4MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1129,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1139,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1148,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1158,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -1167,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1195,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv4.ProtocolNumber,
nicAddr: localIPv4Addr,
remoteAddr: remoteIPv4Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
totalLen := ipHdrLen + len(data)
hdr := buffer.NewPrependable(totalLen)
@@ -1213,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
}
- return hdr.View()
+ return hdr.View().ToVectorisedView()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options and data across views",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ vv := buffer.View(ip).ToVectorisedView()
+ vv.AppendView(ipv4Options)
+ vv.AppendView(data)
+ return vv
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv4Any {
@@ -1243,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1256,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1283,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
hdr := buffer.NewPrependable(totalLen)
if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
@@ -1299,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return hdr.View()
+ return hdr.View().ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1326,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1334,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip)
+ return buffer.View(ip).ToVectorisedView()
},
checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
if src == header.IPv6Any {
@@ -1361,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
protoNum: ipv6.ProtocolNumber,
nicAddr: localIPv6Addr,
remoteAddr: remoteIPv6Addr,
- pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView {
ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
NextHeader: transportProto,
@@ -1369,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
SrcAddr: src,
DstAddr: header.IPv4Any,
})
- return buffer.View(ip[:len(ip)-1])
+ return buffer.View(ip[:len(ip)-1]).ToVectorisedView()
},
expectedErr: tcpip.ErrMalformedHeader,
},
@@ -1413,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) {
defer r.Release()
if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
- Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ Data: test.pktGen(t, subTest.srcAddr),
})); err != test.expectedErr {
t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7fc12e229..6252614ec 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -29,6 +29,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3407755ed..9b5e37fee 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -23,10 +24,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// handleControl handles the case when an ICMP packet contains the headers of
-// the original packet that caused the ICMP one to be sent. This information is
-// used to find out which transport endpoint must be notified about the ICMP
-// packet.
+// handleControl handles the case when an ICMP error packet contains the headers
+// of the original packet that caused the ICMP one to be sent. This information
+// is used to find out which transport endpoint must be notified about the ICMP
+// packet. We only expect the payload, not the enclosing ICMP packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
@@ -41,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
- src := hdr.SourceAddress()
- if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
+ srcAddr := hdr.SourceAddress()
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 {
return
}
@@ -57,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
- stats := r.Stats()
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) {
+ stats := e.protocol.stack.Stats()
received := stats.ICMP.V4PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
// TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
@@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
h := header.ICMPv4(v)
+ // Only do in-stack processing if the checksum is correct.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
+ received.Invalid.Increment()
+ // It's possible that a raw socket expects to receive this regardless
+ // of checksum errors. If it's an echo request we know it's safe because
+ // we are the only handler, however other types do not cope well with
+ // packets with checksum errors.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
+ }
+ return
+ }
+
+ iph := header.IPv4(pkt.NetworkHeader().View())
+ var newOptions header.IPv4Options
+ if len(iph) > header.IPv4MinimumSize {
+ // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
+ // type ICMP packets):
+ // If a Record Route and/or Time Stamp option is received in an
+ // ICMP Echo Request, this option (these options) SHOULD be
+ // updated to include the current host and included in the IP
+ // header of the Echo Reply message, without "truncation".
+ // Thus, the recorded route will be for the entire round trip.
+ //
+ // So we need to let the option processor know how it should handle them.
+ var op optionsUsage
+ if h.Type() == header.ICMPv4Echo {
+ op = &optionUsageEcho{}
+ } else {
+ op = &optionUsageReceive{}
+ }
+ aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op)
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ newOptions = tmp
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
received.Echo.Increment()
- // Only send a reply if the checksum is valid.
- headerChecksum := h.Checksum()
- h.SetChecksum(0)
- calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- h.SetChecksum(headerChecksum)
- if calculatedChecksum != headerChecksum {
- // It's possible that a raw socket still expects to receive this.
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- received.Invalid.Increment()
+ sent := stats.ICMP.V4PacketsSent
+ if !e.protocol.stack.AllowICMPMessage() {
+ sent.RateLimited.Increment()
return
}
@@ -98,19 +144,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
replyData := pkt.Data.ToOwnedView()
- replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
+ ipHdr := header.IPv4(pkt.NetworkHeader().View())
+ localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast
+
+ // It's possible that a raw socket expects to receive this.
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
+ pkt = nil
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ // Take the base of the incoming request IP header but replace the options.
+ replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
+ replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...))
+ replyIPHdr.SetHeaderLength(replyHeaderLength)
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
// or multicast address).
- localAddr := r.LocalAddress
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) {
+ localAddr := ipHdr.DestinationAddress()
+ if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -139,7 +193,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// The fields we need to alter.
//
// We need to produce the entire packet in the data segment in order to
- // use WriteHeaderIncludedPacket().
+ // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the
+ // total length and the header checksum so we don't need to set those here.
replyIPHdr.SetSourceAddress(r.LocalAddress)
replyIPHdr.SetDestinationAddress(r.RemoteAddress)
replyIPHdr.SetTTL(r.DefaultTTL())
@@ -157,8 +212,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
})
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // The checksum will be calculated so we don't need to do it here.
- sent := stats.ICMP.V4PacketsSent
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
@@ -168,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
@@ -182,8 +235,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
}
case header.ICMPv4SrcQuench:
@@ -234,12 +290,31 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
+// icmpReasonParamProblem is an error to use to request a Parameter Problem
+// message to be sent.
+type icmpReasonParamProblem struct {
+ pointer byte
+}
+
+func (*icmpReasonParamProblem) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
// possible as well as any error metadata as is available. returnError
// expects pkt to hold a valid IPv4 packet as per the wire format.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv4(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// We check we are responding only when we are allowed to.
// See RFC 1812 section 4.3.2.7 (shown below).
//
@@ -263,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
//
// TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in
// response to a non-initial fragment, but it currently can not happen.
-
- if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any {
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any {
return nil
}
@@ -272,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
sent := p.stack.Stats().ICMP.V4PacketsSent
if !p.stack.AllowICMPMessage() {
@@ -287,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- networkHeader := pkt.NetworkHeader().View()
transportHeader := pkt.TransportHeader().View()
// Don't respond to icmp error packets.
- if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) {
+ if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) {
// TODO(gvisor.dev/issue/3810):
// Unfortunately the current stack pretty much always has ICMPv4 headers
// in the Data section of the packet but there is no guarantee that is the
@@ -348,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
return nil
}
- payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size()
+ payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
@@ -360,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// view with the entire incoming IP packet reassembled and truncated as
// required. This is now the payload of the new ICMP packet and no longer
// considered a packet in its own right.
- newHeader := append(buffer.View(nil), networkHeader...)
+ newHeader := append(buffer.View(nil), origIPHdr...)
newHeader = append(newHeader, transportHeader...)
payload := newHeader.ToVectorisedView()
payload.AppendView(pkt.Data.ToView())
@@ -374,17 +444,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- switch reason.(type) {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ counter = sent.DstUnreachable
case *icmpReasonProtoUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
+ counter = sent.TimeExceeded
+ case *icmpReasonParamProblem:
+ icmpHdr.SetType(header.ICMPv4ParamProblem)
+ icmpHdr.SetCode(header.ICMPv4UnusedCode)
+ icmpHdr.SetPointer(reason.pointer)
+ counter = sent.ParamProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
- counter := sent.DstUnreachable
if err := route.WritePacket(
nil, /* gso */
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index e7c58ae0a..cfd0c505a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -16,7 +16,9 @@
package ipv4
import (
+ "errors"
"fmt"
+ "math"
"sync/atomic"
"time"
@@ -31,6 +33,8 @@ import (
)
const (
+ // ReassembleTimeout is the time a packet stays in the reassembly
+ // system before being evicted.
// As per RFC 791 section 3.2:
// The current recommendation for the initial timer setting is 15 seconds.
// This may be changed as experience with this protocol accumulates.
@@ -38,7 +42,7 @@ const (
// Considering that it is an old recommendation, we use the same reassembly
// timeout that linux defines, which is 30 seconds:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
- reassembleTimeout = 30 * time.Second
+ ReassembleTimeout = 30 * time.Second
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
@@ -211,18 +219,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
-func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU())
-}
-
// handleFragments fragments pkt and calls the handler function on each
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
-// original packet. The mtu is the maximum size of the packets.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
+// original packet.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ // Round the MTU down to align to 8 bytes.
+ fragmentPayloadSize := networkMTU &^ 7
networkHeader := header.IPv4(pkt.NetworkHeader().View())
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader))
var n int
for {
@@ -247,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
- ipt := e.protocol.stack.IPTables()
- if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
return nil
@@ -265,23 +269,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
netHeader := header.IPv4(pkt.NetworkHeader().View())
ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, pkt)
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -292,6 +313,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
return err
}
+
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -311,17 +333,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.addIPHeader(r, pkt, params)
- if e.packetMustBeFragmented(pkt, gso) {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pkt
- if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pkt, fragPkt)
pkt = fragPkt
return nil
}); err != nil {
- panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err))
+ panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err))
}
// Remove the packet that was just fragmented and process the rest.
pkts.Remove(originalPkt)
@@ -355,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -385,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
if !ok {
return tcpip.ErrMalformedHeader
}
+
+ hdrLen := header.IPv4(h).HeaderLength()
+ if hdrLen < header.IPv4MinimumSize {
+ return tcpip.ErrMalformedHeader
+ }
+
+ h, ok = pkt.Data.PullUp(int(hdrLen))
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
ip := header.IPv4(h)
// Always set the total length.
@@ -429,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv4(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -462,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
if h.CalculateChecksum() != 0xffff {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
@@ -470,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
// multicast address).
- if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -480,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -488,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
// The packet is a fragment, let's try to reassemble it.
@@ -502,10 +545,30 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// size). Otherwise the packet would've been rejected as invalid before
// reaching here.
if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
+
+ // Set up a callback in case we need to send a Time Exceeded Message, as per
+ // RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards
+ // the datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+ }
+ }
+
var ready bool
var err error
proto := h.Protocol()
@@ -523,29 +586,56 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
h.More(),
proto,
pkt.Data,
+ releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if !ready {
return
}
+
+ // The reassembler doesn't take care of fixing up the header, so we need
+ // to do it here.
+ h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
+ h.SetFlagsFragmentOffset(0, 0)
}
+ stats.IP.PacketsDelivered.Increment()
- r.Stats().IP.PacketsDelivered.Increment()
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
// TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
// headers, the setting of the transport number here should be
// unnecessary and removed.
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt)
+ e.handleICMP(pkt)
return
}
+ if len(h.Options()) != 0 {
+ // TODO(gvisor.dev/issue/4586):
+ // When we add forwarding support we should use the verified options
+ // rather than just throwing them away.
+ aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{})
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt)
+ stats.MalformedRcvdPackets.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ }
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
@@ -553,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// 3 (Port Unreachable), when the designated transport protocol
// (e.g., UDP) is unable to demultiplex the datagram but has no
// protocol mechanism to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC: 1122 Section 3.2.2.1
// A host SHOULD generate Destination Unreachable messages with code:
// 2 (Protocol Unreachable), when the designated transport protocol
// is not supported
- _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt)
default:
panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res))
}
@@ -602,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo
loopback := e.nic.IsLoopback()
addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool {
- subnet := addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := addressEndpoint.Subnet()
// IPv4 has a notion of a subnet broadcast address and considers the
// loopback interface bound to an address's whole subnet (on linux).
return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr))
@@ -778,26 +868,32 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload mtu.
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv4MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
}
- return mtu - header.IPv4MinimumSize
-}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+ // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
+ // length:
+ // The maximal internet header is 60 octets, and a typical internet header
+ // is 20 octets, allowing a margin for headers of higher level protocols.
+ if networkHeaderSize > header.IPv4MaximumHeaderSize {
+ return 0, tcpip.ErrMalformedHeader
}
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- return mtu
+
+ networkMTU := linkMTU
+ if networkMTU > MaxTotalSize {
+ networkMTU = MaxTotalSize
+ }
+
+ return networkMTU - uint32(networkHeaderSize), nil
+}
+
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// addressToUint32 translates an IPv4 address into its little endian uint32
@@ -836,7 +932,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
}
}
@@ -846,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
originalIPHeaderLength := len(originalIPHeader)
nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) {
panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength))
@@ -862,3 +959,324 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
return fragPkt, more
}
+
+// optionAction describes possible actions that may be taken on an option
+// while processing it.
+type optionAction uint8
+
+const (
+ // optionRemove says that the option should not be in the output option set.
+ optionRemove optionAction = iota
+
+ // optionProcess says that the option should be fully processed.
+ optionProcess
+
+ // optionVerify says the option should be checked and passed unchanged.
+ optionVerify
+
+ // optionPass says to pass the output set without checking.
+ optionPass
+)
+
+// optionActions list what to do for each option in a given scenario.
+type optionActions struct {
+ // timestamp controls what to do with a Timestamp option.
+ timestamp optionAction
+
+ // recordroute controls what to do with a Record Route option.
+ recordRoute optionAction
+
+ // unknown controls what to do with an unknown option.
+ unknown optionAction
+}
+
+// optionsUsage specifies the ways options may be operated upon for a given
+// scenario during packet processing.
+type optionsUsage interface {
+ actions() optionActions
+}
+
+// optionUsageReceive implements optionsUsage for received packets.
+type optionUsageReceive struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageReceive) actions() optionActions {
+ return optionActions{
+ timestamp: optionVerify,
+ recordRoute: optionVerify,
+ unknown: optionPass,
+ }
+}
+
+// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it
+// is enabled (Process, Process, Pass) and for fragmenting (Process, Process,
+// Pass for frag1, but Remove,Remove,Remove for all other frags).
+
+// optionUsageEcho implements optionsUsage for echo packet processing.
+type optionUsageEcho struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageEcho) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
+ unknown: optionRemove,
+ }
+}
+
+var (
+ errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
+ errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
+ errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
+ errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
+)
+
+// handleTimestamp does any required processing on a Timestamp option
+// in place.
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+ flags := tsOpt.Flags()
+ var entrySize uint8
+ switch flags {
+ case header.IPv4OptionTimestampOnlyFlag:
+ entrySize = header.IPv4OptionTimestampSize
+ case
+ header.IPv4OptionTimestampWithIPFlag,
+ header.IPv4OptionTimestampWithPredefinedIPFlag:
+ entrySize = header.IPv4OptionTimestampWithAddrSize
+ default:
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ }
+
+ pointer := tsOpt.Pointer()
+ // To simplify processing below, base further work on the array of timestamps
+ // beyond the header, rather than on the whole option. Also to aid
+ // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
+ nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1)
+ optLen := tsOpt.Size()
+ dataLength := optLen - header.IPv4OptionTimestampHdrLength
+
+ // In the section below, we verify the pointer, length and overflow counter
+ // fields of the option. The distinction is in which byte you return as being
+ // in error in the ICMP packet. Offsets 1 (length), 2 pointer)
+ // or 3 (overflowed counter).
+ //
+ // The following RFC sections cover this section:
+ //
+ // RFC 791 (page 22):
+ // If there is some room but not enough room for a full timestamp
+ // to be inserted, or the overflow count itself overflows, the
+ // original datagram is considered to be in error and is discarded.
+ // In either case an ICMP parameter problem message may be sent to
+ // the source host [3].
+ //
+ // You can get this situation in two ways. Firstly if the data area is not
+ // a multiple of the entry size or secondly, if the pointer is not at a
+ // multiple of the entry size. The wording of the RFC suggests that
+ // this is not an error until you actually run out of space.
+ if pointer > optLen {
+ // RFC 791 (page 22) says we should switch to using the overflow count.
+ // If the timestamp data area is already full (the pointer exceeds
+ // the length) the datagram is forwarded without inserting the
+ // timestamp, but the overflow count is incremented by one.
+ if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
+ // By definition we have nothing to do.
+ return 0, nil
+ }
+
+ if tsOpt.IncOverflow() != 0 {
+ return 0, nil
+ }
+ // The overflow count is also full.
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ }
+ if nextSlot+entrySize > dataLength {
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad.
+ if false {
+ // We must select Pointer for Linux compatibility, even if
+ // only the length is bad.
+ // The Linux code is at (in October 2020)
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ // which doesn't distinguish between which of optptr[2] or optlen
+ // is wrong, but just arbitrarily decides on optptr+2.
+ if dataLength%entrySize != 0 {
+ // The Data section size should be a multiple of the expected
+ // timestamp entry size.
+ return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ }
+ // If the size is OK, the pointer must be corrupted.
+ }
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
+
+ if usage.actions().timestamp == optionProcess {
+ tsOpt.UpdateTimestamp(localAddress, clock)
+ }
+ return 0, nil
+}
+
+var (
+ errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
+ errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
+)
+
+// handleRecordRoute checks and processes a Record route option. It is much
+// like the timestamp type 1 option, but without timestamps. The passed in
+// address is stored in the option in the correct spot if possible.
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+ optlen := rrOpt.Size()
+
+ if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+
+ nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+
+ // RFC 791 page 21 says
+ // If the route data area is already full (the pointer exceeds the
+ // length) the datagram is forwarded without inserting the address
+ // into the recorded route. If there is some room but not enough
+ // room for a full address to be inserted, the original datagram is
+ // considered to be in error and is discarded. In either case an
+ // ICMP parameter problem message may be sent to the source
+ // host.
+ // The use of the words "In either case" suggests that a 'full' RR option
+ // could generate an ICMP at every hop after it fills up. We chose to not
+ // do this (as do most implementations). It is probable that the inclusion
+ // of these words is a copy/paste error from the timestamp option where
+ // there are two failure reasons given.
+ if nextSlot >= optlen {
+ return 0, nil
+ }
+
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad. We must select Pointer for Linux
+ // compatibility, even if only the length is bad.
+ if nextSlot+header.IPv4AddressSize > optlen {
+ if false {
+ // This is what we would do if we were not being Linux compatible.
+ // Check for bad pointer or length value. Must be a multiple of 4 after
+ // accounting for the 3 byte header and not within that header.
+ // RFC 791, page 20 says:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ //
+ // A recorded route is composed of a series of internet addresses.
+ // Each internet address is 32 bits or 4 octets.
+ // Linux skips this test so we must too. See Linux code at:
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
+ // Length is bad, not on integral number of slots.
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+ // If not length, the fault must be with the pointer.
+ }
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
+ if usage.actions().recordRoute == optionVerify {
+ return 0, nil
+ }
+ rrOpt.StoreAddress(localAddress)
+ return 0, nil
+}
+
+// processIPOptions parses the IPv4 options and produces a new set of options
+// suitable for use in the next step of packet processing as informed by usage.
+// The original will not be touched.
+//
+// Returns
+// - The location of an error if there was one (or 0 if no error)
+// - If there is an error, information as to what it was was.
+// - The replacement option set.
+func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+ stats := e.protocol.stack.Stats()
+ opts := header.IPv4Options(orig)
+ optIter := opts.MakeIterator()
+
+ // Each option other than NOP must only appear (RFC 791 section 3.1, at the
+ // definition of every type). Keep track of each of the possible types in
+ // the 8 bit 'type' field.
+ var seenOptions [math.MaxUint8 + 1]bool
+
+ // TODO(gvisor.dev/issue/4586):
+ // This will need tweaking when we start really forwarding packets
+ // as we may need to get two addresses, for rx and tx interfaces.
+ // We will also have to take usage into account.
+ prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber)
+ localAddress := prefixedAddress.Address
+ if err != nil {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ dstAddr := h.DestinationAddress()
+ if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) {
+ return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ }
+ localAddress = dstAddr
+ }
+
+ for {
+ option, done, err := optIter.Next()
+ if done || err != nil {
+ return optIter.ErrCursor, optIter.Finalize(), err
+ }
+ optType := option.Type()
+ if optType == header.IPv4OptionNOPType {
+ optIter.PushNOPOrEnd(optType)
+ continue
+ }
+ if optType == header.IPv4OptionListEndType {
+ optIter.PushNOPOrEnd(optType)
+ return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ }
+
+ // check for repeating options (multiple NOPs are OK)
+ if seenOptions[optType] {
+ return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ }
+ seenOptions[optType] = true
+
+ optLen := int(option.Size())
+ switch option := option.(type) {
+ case *header.IPv4OptionTimestamp:
+ stats.IP.OptionTSReceived.Increment()
+ if usage.actions().timestamp != optionRemove {
+ clock := e.protocol.stack.Clock()
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ case *header.IPv4OptionRecordRoute:
+ stats.IP.OptionRRReceived.Increment()
+ if usage.actions().recordRoute != optionRemove {
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ default:
+ stats.IP.OptionUnknownReceived.Increment()
+ if usage.actions().unknown == optionPass {
+ newBuffer := optIter.RemainingBuffer()[:optLen]
+ // Arguments already heavily checked.. ignore result.
+ _ = copy(newBuffer, option.Contents())
+ optIter.ConsumeBuffer(optLen)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index fee11bb38..c7f434591 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -21,11 +21,13 @@ import (
"math"
"net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -39,7 +41,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-const extraHeaderReserve = 50
+const (
+ extraHeaderReserve = 50
+ defaultMTU = 65536
+)
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
@@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) {
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
ep = sniffer.New(ep)
@@ -103,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) {
// checks the response.
func TestIPv4Sanity(t *testing.T) {
const (
- defaultMTU = header.IPv6MinimumMTU
ttl = 255
nicID = 1
randomSequence = 123
@@ -118,27 +121,29 @@ func TestIPv4Sanity(t *testing.T) {
)
tests := []struct {
- name string
- headerLength uint8 // value of 0 means "use correct size"
- badHeaderChecksum bool
- maxTotalLength uint16
- transportProtocol uint8
- TTL uint8
- shouldFail bool
- expectICMP bool
- ICMPType header.ICMPv4Type
- ICMPCode header.ICMPv4Code
- options []byte
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ badHeaderChecksum bool
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ options []byte
+ replyOptions []byte // if succeeds, reply should look like this
+ shouldFail bool
+ expectErrorICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ paramProblemPointer uint8
}{
{
- name: "valid",
- maxTotalLength: defaultMTU,
+ name: "valid no options",
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
},
{
name: "bad header checksum",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
badHeaderChecksum: true,
@@ -157,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) {
// received with TTL less than 2.
{
name: "zero TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 0,
- shouldFail: false,
},
{
name: "one TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 1,
- shouldFail: false,
},
{
name: "End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{0, 0, 0, 0},
+ replyOptions: []byte{0, 0, 0, 0},
},
{
name: "NOP options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 1, 1},
+ replyOptions: []byte{1, 1, 1, 1},
},
{
name: "NOP and End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 0, 0},
+ replyOptions: []byte{1, 1, 0, 0},
},
{
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (0)",
@@ -205,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip - 1)",
@@ -213,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip + icmp - 1)",
@@ -221,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad protocol",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: 99,
TTL: ttl,
shouldFail: true,
- expectICMP: true,
+ expectErrorICMP: true,
ICMPType: header.ICMPv4DstUnreachable,
ICMPCode: header.ICMPv4ProtoUnreachable,
},
+ {
+ name: "timestamp option overflow",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ },
+ {
+ name: "timestamp option overflow full",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0xF1,
+ // ^ Counter full (15/0xF)
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ replyOptions: []byte{},
+ },
+ {
+ name: "unknown option",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{10, 4, 9, 0},
+ // ^^
+ // The unknown option should be stripped out of the reply.
+ replyOptions: []byte{},
+ },
+ {
+ name: "bad option - length 0",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 0, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ name: "bad option - length big",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 9, 9, 0,
+ // ^
+ // There are only 8 bytes allocated to options so 9 bytes of timestamp
+ // space is not possible. (Second byte)
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ // This tests for some linux compatible behaviour.
+ // The ICMP pointer returned is 22 for Linux but the
+ // error is actually in spot 21.
+ name: "bad option - length bad",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ // Timestamps are in multiples of 4 or 8 but never 7.
+ // The option space should be padded out.
+ options: []byte{
+ 68, 7, 5, 0,
+ // ^ ^ Linux points here which is wrong.
+ // | Not a multiple of 4
+ 1, 2, 3,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "multiple type 0 with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 24, 21, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 24, 25, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // The timestamp area is full so add to the overflow count.
+ name: "multiple type 1 timestamps",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 21, 0x11,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ // Overflow count is the top nibble of the 4th byte.
+ replyOptions: []byte{
+ 68, 20, 21, 0x21,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ },
+ {
+ name: "multiple type 1 timestamps with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 28, 21, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 28, 29, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 192, 168, 1, 58, // New IP Address.
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
+ name: "bad timer element alignment",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 17, 0x01,
+ // ^^ ^^ 20 byte area, next free spot at 17.
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ // End of option list with illegal option after it, which should be ignored.
+ {
+ name: "end of options list",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 10, 3, 99,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0, // 3 bytes unknown option
+ }, // ^ End of options hides following bytes.
+ },
+ {
+ // Timestamp with a size too small.
+ name: "timestamp truncated",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{68, 1, 0, 0},
+ // ^ Smallest possible is 8.
+ shouldFail: true,
+ },
+ {
+ name: "single record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 4, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "multiple record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 20, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "single record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Unlike timestamp, this should just succeed.
+ name: "multiple record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 24, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm linux bug for bug compatibility.
+ // Linux returns slot 22 but the error is in slot 21.
+ name: "multiple record route with not enough room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 8, 8, // 3 byte header
+ // ^ ^ Linux points here. We must too.
+ // | Not enough room. 1 byte free, need 4.
+ 1, 2, 3, 4,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ replyOptions: []byte{},
+ },
+ {
+ name: "duplicate record route",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, 0, // pad
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 7,
+ replyOptions: []byte{},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ Clock: clock,
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e := channel.New(1, defaultMTU, "")
+ e := channel.New(1, ipv4.MaxTotalSize, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -250,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) {
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
+ // Advance the clock by some unimportant amount to make
+ // sure it's all set up.
+ clock.Advance(time.Millisecond * 0x10203040)
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
@@ -312,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) {
reply, ok := e.Read()
if !ok {
if test.shouldFail {
- if test.expectICMP {
- t.Fatal("expected ICMP error response missing")
+ if test.expectErrorICMP {
+ t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
}
return // Expected silent failure.
}
t.Fatal("expected ICMP echo reply missing")
}
+ // We didn't expect a packet. Register our surprise but carry on to
+ // provide more information about what we got.
+ if test.shouldFail && !test.expectErrorICMP {
+ t.Error("unexpected packet response")
+ }
+
// Check the route that brought the packet to us.
if reply.Route.LocalAddress != ipv4Addr.Address {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
@@ -328,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
}
- // Make sure it's all in one buffer.
- vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
- replyIPHeader := header.IPv4(vv.ToView())
+ // Make sure it's all in one buffer for checker.
+ replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
- // At this stage we only know it's an IP header so verify that much.
+ // At this stage we only know it's probably an IP+ICMP header so verify
+ // that much.
checker.IPv4(t, replyIPHeader,
checker.SrcAddr(ipv4Addr.Address),
checker.DstAddr(remoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ ),
)
- // All expected responses are ICMP packets.
- if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
- t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ // Don't proceed any further if the checker found problems.
+ if t.Failed() {
+ t.FailNow()
}
- replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
- // Sanity check the response.
+ // OK it's ICMP. We can safely look at the type now.
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
switch replyICMPHeader.Type() {
- case header.ICMPv4DstUnreachable:
+ case header.ICMPv4ParamProblem:
+ if !test.shouldFail {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
+ }
checker.IPv4(t, replyIPHeader,
checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
checker.IPv4HeaderLength(header.IPv4MinimumSize),
checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Checksum(),
+ checker.ICMPv4Pointer(test.paramProblemPointer),
checker.ICMPv4Payload([]byte(hdr.View())),
),
)
- if !test.shouldFail || !test.expectICMP {
- t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ return
+ case header.ICMPv4DstUnreachable:
+ if !test.shouldFail {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
header.ICMPv4DstUnreachable, replyICMPHeader.Code())
}
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
return
case header.ICMPv4EchoReply:
+ if test.shouldFail {
+ if !test.expectErrorICMP {
+ t.Error("got Echo Reply packet, want no response")
+ } else {
+ t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
+ }
+ }
+ // If the IP options change size then the packet will change size, so
+ // some IP header fields will need to be adjusted for the checks.
+ sizeChange := len(test.replyOptions) - len(test.options)
+
checker.IPv4(t, replyIPHeader,
- checker.IPv4HeaderLength(ipHeaderLength),
- checker.IPv4Options(test.options),
- checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
+ checker.IPv4Options(test.replyOptions),
+ checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
checker.ICMPv4(
+ checker.ICMPv4Checksum(),
checker.ICMPv4Code(header.ICMPv4UnusedCode),
checker.ICMPv4Seq(randomSequence),
checker.ICMPv4Ident(randomIdent),
- checker.ICMPv4Checksum(),
),
)
- if test.shouldFail {
- t.Fatalf("unexpected Echo Reply packet\n")
- }
default:
- t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
- replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
}
})
}
@@ -462,7 +840,7 @@ var fragmentationTests = []struct {
wantFragments []fragmentInfo
}{
{
- description: "No Fragmentation",
+ description: "No fragmentation",
mtu: 1280,
gso: nil,
transportHeaderLength: 0,
@@ -483,6 +861,30 @@ var fragmentationTests = []struct {
},
},
{
+ description: "Fragmented with the minimum mtu",
+ mtu: header.IPv4MinimumMTU,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv4MinimumMTU + 1,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
description: "No fragmentation with big header",
mtu: 2000,
gso: nil,
@@ -647,43 +1049,50 @@ func TestFragmentationWritePackets(t *testing.T) {
}
}
-// TestFragmentationErrors checks that errors are returned from write packet
+// TestFragmentationErrors checks that errors are returned from WritePacket
// correctly.
func TestFragmentationErrors(t *testing.T) {
const ttl = 42
- expectedError := tcpip.ErrAborted
- fragTests := []struct {
+ tests := []struct {
description string
mtu uint32
transportHeaderLength int
payloadSize int
allowPackets int
- fragmentCount int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
}{
{
description: "No frag",
mtu: 2000,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 1,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 3,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on second frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 1,
- fragmentCount: 3,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag MTU smaller than header",
@@ -691,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 1000,
payloadSize: 500,
allowPackets: 0,
- fragmentCount: 4,
+ outgoingErrors: 4,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error when MTU is smaller than IPv4 minimum MTU",
+ mtu: header.IPv4MinimumMTU - 1,
+ transportHeaderLength: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
},
}
- for _, ft := range fragTests {
+ for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
- r := buildRoute(t, ep)
pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
+ r := buildRoute(t, ep)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != expectedError {
- t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
}
- if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
}
- if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
}
})
}
@@ -744,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) {
autoChecksum bool // if true, the Checksum field will be overwritten.
}
- // These packets have both IHL and TotalLength set to 0.
tests := []struct {
name string
fragments []fragmentData
@@ -984,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
-
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -1027,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) {
}
}
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 99
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ },
+ Clock: clock,
+ })
+ e := channel.New(1, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(header.IPv4ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ipv4.ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ ),
+ )
+ })
+ }
+}
+
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
@@ -1506,13 +2178,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
@@ -1527,17 +2196,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
@@ -1577,7 +2243,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -1783,7 +2449,7 @@ func TestPacketQueing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e := channel.New(1, defaultMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index a30437f02..0ac24a6fb 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -36,6 +36,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index ead6bedcb..8502b848c 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) {
})
}
-func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
- stats := r.Stats().ICMP
+func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) {
+ stats := e.protocol.stack.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
// TODO(gvisor.dev/issue/170): ICMP packets don't have their
@@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
}
h := header.ICMPv6(v)
iph := header.IPv6(pkt.NetworkHeader().View())
+ srcAddr := iph.SourceAddress()
+ dstAddr := iph.DestinationAddress()
// Validate ICMPv6 checksum before processing the packet.
//
// This copy is used as extra payload during the checksum calculation.
payload := pkt.Data.Clone(nil)
payload.TrimFront(len(h))
- if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
+ if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want {
received.Invalid.Increment()
return
}
@@ -170,8 +172,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := header.ICMPv6(hdr).MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -221,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// we know we are also performing DAD on it). In this case we let the
// stack know so it can handle such a scenario and do nothing further with
// the NS.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
// We would get an error if the address no longer exists or the address
// is no longer tentative (DAD resolved between the call to
// hasTentativeAddr and this point). Both of these are valid scenarios:
@@ -248,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -274,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// Otherwise, on link layers that have addresses this option MUST be
// included in multicast solicitations and SHOULD be included in unicast
// solicitations.
- unspecifiedSource := r.RemoteAddress == header.IPv6Any
+ unspecifiedSource := srcAddr == header.IPv6Any
if len(sourceLinkAddr) == 0 {
- if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource {
received.Invalid.Increment()
return
}
@@ -284,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
} else if e.nud != nil {
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr)
}
// As per RFC 4861 section 7.1.1:
@@ -295,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// ...
// - If the IP source address is the unspecified address, the IP
// destination address is a solicited-node multicast address.
- if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) {
+ if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) {
received.Invalid.Increment()
return
}
@@ -305,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the source of the solicitation is the unspecified address, the node
// MUST [...] and multicast the advertisement to the all-nodes address.
//
- remoteAddr := r.RemoteAddress
+ remoteAddr := srcAddr
if unspecifiedSource {
remoteAddr = header.IPv6AllNodesMulticastAddress
}
@@ -462,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4291 section 2.7, multicast addresses must not be used as
// source addresses in IPv6 packets.
- localAddr := r.LocalAddress
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr := dstAddr
+ if header.IsV6MulticastAddress(dstAddr) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
@@ -483,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
copy(packet, icmpHdr)
packet.SetType(header.ICMPv6EchoReply)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
- if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -495,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
+ e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -516,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- stack := r.Stack()
+ stack := e.protocol.stack
// Is the networking stack operating as a router?
if !stack.Forwarding(ProtocolNumber) {
@@ -547,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST
// NOT be included when the source IP address is the unspecified address.
// Otherwise, it SHOULD be included on link layers that have addresses.
- if r.RemoteAddress == header.IPv6Any {
+ if srcAddr == header.IPv6Any {
received.Invalid.Increment()
return
}
@@ -555,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
if e.nud != nil {
// A RS with a specified source IP address modifies the NUD state
// machine in the same way a reachability probe would.
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
}
@@ -572,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
- routerAddr := iph.SourceAddress()
+ routerAddr := srcAddr
// Is the IP Source Address a link-local address?
if !header.IsV6LinkLocalAddress(routerAddr) {
@@ -605,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol)
}
e.mu.Lock()
@@ -648,52 +657,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
- // route here. Note, we would need the nicID to do this properly so the right
- // NIC (associated to linkEP) is used to send the NDP NS message.
- r := stack.Route{
- LocalAddress: localAddr,
- RemoteAddress: addr,
- LocalLinkAddress: linkEP.LinkAddress(),
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ remoteAddr := targetAddr
+ if len(remoteLinkAddr) == 0 {
+ remoteAddr = header.SolicitedNodeAddr(targetAddr)
+ remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- // If a remote address is not already known, then send a multicast
- // solicitation since multicast addresses have a static mapping to link
- // addresses.
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteAddress = header.SolicitedNodeAddr(addr)
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
+ r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
}
+ defer r.Release()
+ r.ResolveWith(remoteLinkAddr)
optsSerializer := header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
ns := header.NDPNeighborSolicit(packet.NDPPayload())
- ns.SetTargetAddress(addr)
+ ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
+ stat := p.stack.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt); err != nil {
+ stat.Dropped.Increment()
+ return err
+ }
- // TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
+ stat.NeighborSolicit.Increment()
+ return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -747,9 +750,20 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
-func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
+ origIPHdr := header.IPv6(pkt.NetworkHeader().View())
+ origIPHdrSrc := origIPHdr.SourceAddress()
+ origIPHdrDst := origIPHdr.DestinationAddress()
+
// Only send ICMP error if the address is not a multicast v6
// address and the source is not the unspecified address.
//
@@ -776,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
allowResponseToMulticast = reason.respondToMulticast
}
- if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any {
+ if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any {
return nil
}
@@ -784,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
// a route to it - the remote may be blocked via routing rules. We must always
// consult our routing table and find a route to the remote before sending any
// packet.
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
defer route.Release()
- // From this point on, the incoming route should no longer be used; route
- // must be used to send the ICMP error.
- r = nil
stats := p.stack.Stats().ICMP
sent := stats.V6PacketsSent
@@ -839,7 +850,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload := network.ToVectorisedView()
+ payload.AppendView(transport)
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -860,6 +873,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
+ counter = sent.TimeExceeded
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 8dc33c560..76013daa1 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -51,6 +51,7 @@ const (
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
+ lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
@@ -86,7 +87,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
+func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition {
return stack.TransportPacketHandled
}
@@ -108,31 +109,27 @@ type stubNUDHandler struct {
var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) {
s.probeCount++
}
-func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
s.confirmationCount++
}
-func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) {
}
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- stack.NetworkLinkEndpoint
-
- linkAddr tcpip.LinkAddress
-}
+ stack.LinkEndpoint
-func (i *testInterface) LinkAddress() tcpip.LinkAddress {
- return i.linkAddr
+ nicID tcpip.NICID
}
func (*testInterface) ID() tcpip.NICID {
- return 0
+ return nicID
}
func (*testInterface) IsLoopback() bool {
@@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool {
return true
}
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -277,7 +282,8 @@ func TestICMPCounts(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -419,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
for _, typ := range types {
@@ -1235,26 +1242,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
snaddr := header.SolicitedNodeAddr(lladdr0)
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectedLinkAddr tcpip.LinkAddress
- expectedAddr tcpip.Address
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedRemoteAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectedLinkAddr: linkAddr1,
- expectedAddr: lladdr0,
+ name: "Unicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectedLinkAddr: mcaddr,
- expectedAddr: snaddr,
+ name: "Multicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Unicast with no local address available",
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
@@ -1269,26 +1322,43 @@ func TestLinkAddressRequest(t *testing.T) {
}
linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
+ }
+
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
}
pkt, ok := linkEP.Read()
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
}
- if pkt.Route.RemoteAddress != test.expectedAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
}
if pkt.Route.LocalAddress != lladdr1 {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
}
checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedAddr),
+ checker.DstAddr(test.expectedRemoteAddr),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(lladdr0),
@@ -1698,7 +1768,7 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
nudHandler := &stubNUDHandler{}
- ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -1728,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) {
SrcAddr: r.RemoteAddress,
DstAddr: r.LocalAddress,
})
- ep.HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
// Confirm the endpoint calls the correct NUDHandler method.
if nudHandler.probeCount != test.wantProbeCount {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 9670696c7..0526190cc 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -41,12 +41,12 @@ const (
//
// Linux also uses 60 seconds for reassembly timeout:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
- reassembleTimeout = 60 * time.Second
+ ReassembleTimeout = 60 * time.Second
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
- // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // maxPayloadSize is the maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
@@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return err
}
- prefix := addressEndpoint.AddressWithPrefix().Subnet()
+ prefix := addressEndpoint.Subnet()
switch t := addressEndpoint.ConfigType(); t {
case stack.AddressConfigStatic:
@@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
@@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
-func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU())
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// handleFragments fragments pkt and calls the handler function on each
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
-// original packet. The mtu is the maximum size of the packets. The transport
-// header protocol number is required to avoid parsing the IPv6 extension
-// headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
- if fragMTU < pkt.TransportHeader().View().Size() {
+// original packet. The transport header protocol number is required to avoid
+// parsing the IPv6 extension headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // maximum payload length because they should only be transmitted once.
+ fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7
+ if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit {
+ // We need at least 8 bytes of space left for the fragmentable part because
+ // the fragment payload must obviously be non-zero and must be a multiple
+ // of 8 as per RFC 8200 section 4.5:
+ // Each complete fragment, except possibly the last ("rightmost") one, is
+ // an integer multiple of 8 octets long.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
return 0, 1, tcpip.ErrMessageTooLong
}
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
- networkHeader := header.IPv6(pkt.NetworkHeader().View())
var n int
for {
@@ -448,28 +465,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ // Since we rewrote the packet but it is being routed back to us, we can
+ // safely assume the checksum is valid.
+ pkt.RXTransportChecksumValidated = true
+ ep.HandlePacket(pkt)
+ }
return nil
}
}
if r.Loop&stack.PacketLoop != 0 {
- loopedR := r.MakeLoopedRoute()
-
- e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
- // The inbound path expects an unparsed packet.
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- }))
-
- loopedR.Release()
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ loopedR := r.MakeLoopedRoute()
+ loopedR.PopulatePacketInfo(pkt)
+ loopedR.Release()
+ e.HandlePacket(pkt)
+ }
}
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -499,13 +528,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
- if e.packetMustBeFragmented(pb, gso) {
+
+ networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ if packetMustBeFragmented(pb, networkMTU, gso) {
// Keep track of the packet that is about to be fragmented so it can be
// removed once the fragmentation is done.
originalPkt := pb
- if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
pkts.InsertAfter(pb, fragPkt)
pb = fragPkt
@@ -546,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil {
- src := netHeader.SourceAddress()
- dst := netHeader.DestinationAddress()
- route := r.ReverseRoute(src, dst)
- ep.HandlePacket(&route, pkt)
+ pkt := pkt.CloneToInbound()
+ if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ route.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
+ }
n++
continue
}
@@ -569,7 +607,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required checks.
h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
@@ -607,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
if !e.isEnabled() {
return
}
+ pkt.NICID = e.nic.ID()
+ stats := e.protocol.stack.Stats()
+
h := header.IPv6(pkt.NetworkHeader().View())
if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
+ srcAddr := h.SourceAddress()
+ dstAddr := h.DestinationAddress()
// As per RFC 4291 section 2.7:
// Multicast addresses must not be used as source addresses in IPv6
// packets or appear in any Routing header.
- if header.IsV6MulticastAddress(r.RemoteAddress) {
- r.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ if header.IsV6MulticastAddress(srcAddr) {
+ stats.IP.InvalidSourceAddressesReceived.Increment()
return
}
@@ -641,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
- r.Stats().IP.IPTablesInputDropped.Increment()
+ stats.IP.IPTablesInputDropped.Increment()
return
}
@@ -651,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
previousHeaderStart := it.HeaderOffset()
extHdr, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -663,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// As per RFC 8200 section 4.1, the Hop By Hop extension header is
// restricted to appear immediately after an IPv6 fixed header.
if previousHeaderStart != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: previousHeaderStart,
}, pkt)
@@ -675,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -689,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -702,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -727,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// header, so we just make sure Segments Left is zero before processing
// the next extension header.
if extHdr.SegmentsLeft() != 0 {
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6ErroneousHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -747,6 +790,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
continue
}
+ fragmentFieldOffset := it.ParseOffset()
+
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
@@ -762,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
it, done, err := it.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
if done {
@@ -790,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
switch lastHdr.(type) {
case header.IPv6RawPayloadHeader:
default:
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
}
@@ -799,30 +844,70 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
fragmentPayloadLen := rawPayload.Buf.Size()
if fragmentPayloadLen == 0 {
// Drop the packet as it's marked as a fragment but has no payload.
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length of a fragment, as derived from the fragment packet's
+ // Payload Length field, is not a multiple of 8 octets and the M flag
+ // of that fragment is 1, then that fragment must be discarded and an
+ // ICMP Parameter Problem, Code 0, message should be sent to the source
+ // of the fragment, pointing to the Payload Length field of the
+ // fragment packet.
+ if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6PayloadLenOffset,
+ }, pkt)
return
}
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- // Drop the fragment if the size of the reassembled payload would exceed
- // the maximum payload size.
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length and offset of a fragment are such that the Payload
+ // Length of the packet reassembled from that fragment would exceed
+ // 65,535 octets, then that fragment must be discarded and an ICMP
+ // Parameter Problem, Code 0, message should be sent to the source of
+ // the fragment, pointing to the Fragment Offset field of the fragment
+ // packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: fragmentFieldOffset,
+ }, pkt)
return
}
+ // Set up a callback in case we need to send a Time Exceeded Message as
+ // per RFC 2460 Section 4.5.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt)
+ }
+ }
+ }
+
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
// IPv6 ignores the Protocol field since the ID only needs to be unique
// across source-destination pairs, as per RFC 8200 section 4.5.
fragmentation.FragmentID{
- Source: h.SourceAddress(),
- Destination: h.DestinationAddress(),
+ Source: srcAddr,
+ Destination: dstAddr,
ID: extHdr.ID(),
},
start,
@@ -830,10 +915,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.More(),
uint8(rawPayload.Identifier),
rawPayload.Buf,
+ releaseCB,
)
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
- r.Stats().IP.MalformedFragmentsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedFragmentsReceived.Increment()
return
}
pkt.Data = data
@@ -852,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
for {
opt, done, err := optsIt.Next()
if err != nil {
- r.Stats().IP.MalformedPacketsReceived.Increment()
+ stats.IP.MalformedPacketsReceived.Increment()
return
}
if done {
@@ -866,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
case header.IPv6OptionUnknownActionDiscard:
return
case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
- if header.IsV6MulticastAddress(r.LocalAddress) {
+ if header.IsV6MulticastAddress(dstAddr) {
return
}
fallthrough
@@ -879,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// ICMP Parameter Problem, Code 2, message to the packet's
// Source Address, pointing to the unrecognized Option Type.
//
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownOption,
pointer: it.ParseOffset() + optsIt.OptionOffset(),
respondToMulticast: true,
@@ -902,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
pkt.Data = extHdr.Buf
- r.Stats().IP.PacketsDelivered.Increment()
+ stats.IP.PacketsDelivered.Increment()
if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
pkt.TransportProtocolNumber = p
- e.handleICMP(r, pkt, hasFragmentHeader)
+ e.handleICMP(pkt, hasFragmentHeader)
} else {
- r.Stats().IP.PacketsDelivered.Increment()
- switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
+ stats.IP.PacketsDelivered.Increment()
+ switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res {
case stack.TransportPacketHandled:
case stack.TransportPacketDestinationPortUnreachable:
// As per RFC 4443 section 3.1:
@@ -916,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// message with Code 4 in response to a packet for which the
// transport protocol (e.g., UDP) has no listener, if that transport
// protocol has no alternative means to inform the sender.
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt)
+ _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt)
case stack.TransportPacketProtocolUnreachable:
// As per RFC 8200 section 4. (page 7):
// Extension headers are numbered from IANA IP Protocol Numbers
@@ -937,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
//
// Which when taken together indicate that an unknown protocol should
// be treated as an unrecognized next header value.
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
@@ -947,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
}
default:
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ _ = e.protocol.returnError(&icmpReasonParameterProblem{
code: header.ICMPv6UnknownHeader,
pointer: it.ParseOffset(),
}, pkt)
- r.Stats().UnknownProtocolRcvdPackets.Increment()
+ stats.UnknownProtocolRcvdPackets.Increment()
return
}
}
@@ -1427,14 +1513,31 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- mtu -= header.IPv6MinimumSize
- if mtu <= maxPayloadSize {
- return mtu
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU and the length of every IPv6 header.
+// Note that this is different than the Payload Length field of the IPv6 header,
+// which includes the length of the extension headers.
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv6MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // As per RFC 7112 section 5, we should discard packets if their IPv6 header
+ // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not
+ // support PMTU discovery:
+ // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain
+ // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280
+ // bytes ensures that the header chain length does not exceed the IPv6
+ // minimum MTU.
+ if networkHeadersLen > header.IPv6MinimumMTU {
+ return 0, tcpip.ErrMalformedHeader
+ }
+
+ networkMTU := linkMTU - uint32(networkHeadersLen)
+ if networkMTU > maxPayloadSize {
+ networkMTU = maxPayloadSize
}
- return maxPayloadSize
+ return networkMTU, nil
}
// Options holds options to configure a new protocol.
@@ -1488,7 +1591,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
ids: ids,
hashIV: hashIV,
@@ -1509,23 +1612,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
- // supported for outbound packets, their length should not affect the fragment
- // MTU because they should only be transmitted once.
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- mtu -= header.IPv6FragmentHeaderSize
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- if mtu <= maxPayloadSize {
- return mtu
- }
- return maxPayloadSize
-}
-
func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
}
@@ -1560,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea
originalIPHeadersLength := len(originalIPHeaders)
fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize
fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength))
+ fragPkt.NetworkProtocolNumber = ProtocolNumber
// Copy the IPv6 header and any extension headers already populated.
if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 297868f24..1bfcdde25 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
@@ -238,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
@@ -271,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -825,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1844,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1912,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- nicID = 1
- fragmentExtHdrLen = 8
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_INVALID_IPV6_FRAGMENTS"
)
- payloadGen := func(payloadLen int) []byte {
- payload := make([]byte, payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = 0x30
- }
- return payload
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
}
tests := []struct {
@@ -1929,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
+ expectICMP bool
+ expectICMPType header.ICMPv6Type
+ expectICMPCode header.ICMPv6Code
+ expectICMPTypeSpecific uint32
}{
{
+ name: "fragment size is not a multiple of 8 and the M flag is true",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0 >> 3,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:9],
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
+ },
+ {
name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
fragments: []fragmentData{
{
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
- []buffer.View{
- // Fragment extension header.
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
- ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
- ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
- 0, 0, 0, 1}),
- // Payload length = 16
- payloadGen(16),
- },
- ),
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
},
},
wantMalformedIPPackets: 1,
wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
},
}
@@ -1964,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) {
NewProtocol,
},
})
- e := channel.New(0, 1500, linkAddr1)
+ e := channel.New(1, 1500, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+ var expectICMPPayload buffer.View
for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
- // Serialize IPv6 fixed header.
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
- })
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
- vv.Append(f.data)
+ vv.AppendView(f.payload)
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- }))
+ })
+
+ if test.expectICMP {
+ expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
}
if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
@@ -1999,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
}
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(test.expectICMPType),
+ checker.ICMPv6Code(test.expectICMPCode),
+ checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
+ checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ ),
+ )
+ })
+ }
+}
+
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ Clock: clock,
+ })
+
+ e := channel.New(1, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
+ checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ ),
+ )
})
}
}
@@ -2035,13 +2360,10 @@ func TestWriteStats(t *testing.T) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
@@ -2056,17 +2378,14 @@ func TestWriteStats(t *testing.T) {
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
- filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */)
- if !ok {
- t.Fatalf("failed to find filter table")
- }
+ filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Output]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
- if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil {
+ if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
@@ -2230,8 +2549,8 @@ var fragmentationTests = []struct {
wantFragments []fragmentInfo
}{
{
- description: "No Fragmentation",
- mtu: 1280,
+ description: "No fragmentation",
+ mtu: header.IPv6MinimumMTU,
gso: nil,
transHdrLen: 0,
payloadSize: 1000,
@@ -2241,7 +2560,18 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv6MinimumMTU + 1,
gso: nil,
transHdrLen: 0,
payloadSize: 2000,
@@ -2262,7 +2592,7 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented with gso none",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
gso: &stack.GSO{Type: stack.GSONone},
transHdrLen: 0,
payloadSize: 1400,
@@ -2273,7 +2603,7 @@ var fragmentationTests = []struct {
},
{
description: "Fragmented with big header",
- mtu: 1280,
+ mtu: header.IPv6MinimumMTU,
gso: nil,
transHdrLen: 100,
payloadSize: 1200,
@@ -2448,8 +2778,8 @@ func TestFragmentationErrors(t *testing.T) {
wantError: tcpip.ErrAborted,
},
{
- description: "Error on packet with MTU smaller than transport header",
- mtu: 1280,
+ description: "Error when MTU is smaller than transport header",
+ mtu: header.IPv6MinimumMTU,
transHdrLen: 1500,
payloadSize: 500,
allowPackets: 0,
@@ -2457,6 +2787,16 @@ func TestFragmentationErrors(t *testing.T) {
mockError: nil,
wantError: tcpip.ErrMessageTooLong,
},
+ {
+ description: "Error when MTU is smaller than IPv6 minimum MTU",
+ mtu: header.IPv6MinimumMTU - 1,
+ transHdrLen: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
+ },
}
for _, ft := range tests {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index ac20f217e..981d1371a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -341,7 +341,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
if diff := cmp.Diff(existing, n); diff != "" {
t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
}
neighborByAddr[n.Addr] = n
}
@@ -368,7 +368,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi
}
if ok {
- t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
}
}
})
@@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
}
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
@@ -913,13 +920,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test
if diff := cmp.Diff(existing, n); diff != "" {
t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff)
}
- t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing)
+ t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing)
}
neighborByAddr[n.Addr] = n
}
if neigh, ok := neighborByAddr[lladdr1]; ok {
- t.Fatalf("unexpectedly got neighbor entry: %s", neigh)
+ t.Fatalf("unexpectedly got neighbor entry: %#v", neigh)
}
if test.isValid {
@@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) {
if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
}
- ep.HandlePacket(r, pkt)
+ r.PopulatePacketInfo(pkt)
+ ep.HandlePacket(pkt)
}
var tllData [header.NDPLinkLayerAddressSize]byte
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 4d3acab96..9478f3fb7 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address
addrState = &addressState{
addressableEndpointState: a,
addr: addr,
+ // Cache the subnet in addrState to avoid calls to addr.Subnet() as that
+ // results in allocations on every call.
+ subnet: addr.Subnet(),
}
a.mu.endpoints[addr.Address] = addrState
addrState.mu.Lock()
@@ -361,6 +364,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
return tcpip.ErrInvalidEndpointState
}
+ a.mu.Lock()
+ defer a.mu.Unlock()
return a.removePermanentEndpointLocked(addrState)
}
@@ -664,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil)
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
-
+ subnet tcpip.Subnet
// Lock ordering (from outer to inner lock ordering):
//
// AddressableEndpointState.mu
@@ -684,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
}
+// Subnet implements AddressEndpoint.
+func (a *addressState) Subnet() tcpip.Subnet {
+ return a.subnet
+}
+
// GetKind implements AddressEndpoint.
func (a *addressState) GetKind() AddressKind {
a.mu.RLock()
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 0cd1da11f..9a17efcba 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.Addr
- replyTID.srcPort = rt.Port
+ replyTID.srcAddr = address
+ replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
@@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ length := uint16(len(tcpHeader) + pkt.Data.Size())
+ xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
- } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ } else if r.RequiresTXTransportChecksum() {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index cf042309e..7a501acdc 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 8d6d9a7f1..2d8c883cd 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -22,30 +22,17 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-// tableID is an index into IPTables.tables.
-type tableID int
+// TableID identifies a specific table.
+type TableID int
+// Each value identifies a specific table.
const (
- natID tableID = iota
- mangleID
- filterID
- numTables
+ NATID TableID = iota
+ MangleID
+ FilterID
+ NumTables
)
-// Table names.
-const (
- NATTable = "nat"
- MangleTable = "mangle"
- FilterTable = "filter"
-)
-
-// nameToID is immutable.
-var nameToID = map[string]tableID{
- NATTable: natID,
- MangleTable: mangleID,
- FilterTable: filterID,
-}
-
// HookUnset indicates that there is no hook set for an entrypoint or
// underflow.
const HookUnset = -1
@@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- v4Tables: [numTables]Table{
- natID: Table{
+ v4Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -81,7 +68,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -99,7 +86,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
@@ -122,8 +109,8 @@ func DefaultTables() *IPTables {
},
},
},
- v6Tables: [numTables]Table{
- natID: Table{
+ v6Tables: [NumTables]Table{
+ NATID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -146,7 +133,7 @@ func DefaultTables() *IPTables {
Postrouting: 3,
},
},
- mangleID: Table{
+ MangleID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -164,7 +151,7 @@ func DefaultTables() *IPTables {
Postrouting: HookUnset,
},
},
- filterID: Table{
+ FilterID: Table{
Rules: []Rule{
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}},
@@ -187,10 +174,10 @@ func DefaultTables() *IPTables {
},
},
},
- priorities: [NumHooks][]tableID{
- Prerouting: []tableID{mangleID, natID},
- Input: []tableID{natID, filterID},
- Output: []tableID{mangleID, natID, filterID},
+ priorities: [NumHooks][]TableID{
+ Prerouting: []TableID{MangleID, NATID},
+ Input: []TableID{NATID, FilterID},
+ Output: []TableID{MangleID, NATID, FilterID},
},
connections: ConnTrack{
seed: generateRandUint32(),
@@ -229,26 +216,20 @@ func EmptyNATTable() Table {
}
}
-// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
- id, ok := nameToID[name]
- if !ok {
- return Table{}, false
- }
+// GetTable returns a table with the given id and IP version. It panics when an
+// invalid id is provided.
+func (it *IPTables) GetTable(id TableID, ipv6 bool) Table {
it.mu.RLock()
defer it.mu.RUnlock()
if ipv6 {
- return it.v6Tables[id], true
+ return it.v6Tables[id]
}
- return it.v4Tables[id], true
+ return it.v4Tables[id]
}
-// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
- id, ok := nameToID[name]
- if !ok {
- return tcpip.ErrInvalidOptionValue
- }
+// ReplaceTable replaces or inserts table by name. It panics when an invalid id
+// is provided.
+func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error {
it.mu.Lock()
defer it.mu.Unlock()
// If iptables is being enabled, initialize the conntrack table and
@@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
for _, tableID := range priorities {
// If handlePacket already NATed the packet, we don't need to
// check the NAT table.
- if tableID == natID && pkt.NatDone {
+ if tableID == NATID && pkt.NatDone {
continue
}
var table Table
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 538c4625d..d63e9757c 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -26,13 +28,6 @@ type AcceptTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (at *AcceptTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: at.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
@@ -44,22 +39,11 @@ type DropTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (dt *DropTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: dt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
-// ErrorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const ErrorTargetName = "ERROR"
-
// ErrorTarget logs an error and drops the packet. It represents a target that
// should be unreachable.
type ErrorTarget struct {
@@ -67,14 +51,6 @@ type ErrorTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (et *ErrorTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: et.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
@@ -90,14 +66,6 @@ type UserChainTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (uc *UserChainTarget) ID() TargetID {
- return TargetID{
- Name: ErrorTargetName,
- NetworkProtocol: uc.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
@@ -110,50 +78,39 @@ type ReturnTarget struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *ReturnTarget) ID() TargetID {
- return TargetID{
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
-// RedirectTargetName is used to mark targets as redirect targets. Redirect
-// targets should be reached for only NAT and Mangle tables. These targets will
-// change the destination port/destination IP for packets.
-const RedirectTargetName = "REDIRECT"
-
-// RedirectTarget redirects the packet by modifying the destination port/IP.
+// RedirectTarget redirects the packet to this machine by modifying the
+// destination port/IP. Outgoing packets are redirected to the loopback device,
+// and incoming packets are redirected to the incoming interface (rather than
+// forwarded).
+//
// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
// them.
type RedirectTarget struct {
- // Addr indicates address used to redirect.
- Addr tcpip.Address
-
- // Port indicates port used to redirect.
+ // Port indicates port used to redirect. It is immutable.
Port uint16
- // NetworkProtocol is the network protocol the target is used with.
+ // NetworkProtocol is the network protocol the target is used with. It
+ // is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
-// ID implements Target.ID.
-func (rt *RedirectTarget) ID() TargetID {
- return TargetID{
- Name: RedirectTargetName,
- NetworkProtocol: rt.NetworkProtocol,
- }
-}
-
// Action implements Target.Action.
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
-// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// implementation only works for Prerouting and calls pkt.Clone(), neither
// of which should be the case.
func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Sanity check.
+ if rt.NetworkProtocol != pkt.NetworkProtocolNumber {
+ panic(fmt.Sprintf(
+ "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
+ rt.NetworkProtocol, pkt.NetworkProtocolNumber))
+ }
+
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // Change the address to loopback (127.0.0.1 or ::1) in Output and to
// the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
- rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ address = tcpip.Address([]byte{127, 0, 0, 1})
} else {
- rt.Addr = header.IPv6Loopback
+ address = header.IPv6Loopback
}
case Prerouting:
- rt.Addr = address
+ // No-op, as address is already set correctly.
default:
panic("redirect target is supported only on output and prerouting hooks")
}
@@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
+ netHeader := pkt.Network()
+ netHeader.SetDestinationAddress(address)
// Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ if r.RequiresTXTransportChecksum() {
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
- xsum := r.PseudoHeaderChecksum(protocol, length)
- for _, v := range pkt.Data.Views() {
- xsum = header.Checksum(v, xsum)
- }
- udpHeader.SetChecksum(0)
+ xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- pkt.Network().SetDestinationAddress(rt.Addr)
-
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
@@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
// Set up conection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 7b3f3e88b..4b86c1be9 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -37,7 +37,6 @@ import (
// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
type Hook uint
-// These values correspond to values in include/uapi/linux/netfilter.h.
const (
// Prerouting happens before a packet is routed to applications or to
// be forwarded.
@@ -86,8 +85,8 @@ type IPTables struct {
mu sync.RWMutex
// v4Tables and v6tables map tableIDs to tables. They hold builtin
// tables only, not user tables. mu must be locked for accessing.
- v4Tables [numTables]Table
- v6Tables [numTables]Table
+ v4Tables [NumTables]Table
+ v6Tables [NumTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
@@ -96,7 +95,7 @@ type IPTables struct {
// priorities maps each hook to a list of table names. The order of the
// list is the order in which each table should be visited for that
// hook. It is immutable.
- priorities [NumHooks][]tableID
+ priorities [NumHooks][]TableID
connections ConnTrack
@@ -104,6 +103,24 @@ type IPTables struct {
reaperDone chan struct{}
}
+// VisitTargets traverses all the targets of all tables and replaces each with
+// transform(target).
+func (it *IPTables) VisitTargets(transform func(Target) Target) {
+ it.mu.Lock()
+ defer it.mu.Unlock()
+
+ for tid := range it.v4Tables {
+ for i, rule := range it.v4Tables[tid].Rules {
+ it.v4Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+ for tid := range it.v6Tables {
+ for i, rule := range it.v6Tables[tid].Rules {
+ it.v6Tables[tid].Rules[i].Target = transform(rule.Target)
+ }
+ }
+}
+
// A Table defines a set of chains and hooks into the network stack.
//
// It is a list of Rules, entry points (BuiltinChains), and error handlers
@@ -169,7 +186,6 @@ type IPHeaderFilter struct {
// CheckProtocol determines whether the Protocol field should be
// checked during matching.
- // TODO(gvisor.dev/issue/3549): Check this field during matching.
CheckProtocol bool
// Dst matches the destination IP address.
@@ -309,23 +325,8 @@ type Matcher interface {
Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
-// A TargetID uniquely identifies a target.
-type TargetID struct {
- // Name is the target name as stored in the xt_entry_target struct.
- Name string
-
- // NetworkProtocol is the protocol to which the target applies.
- NetworkProtocol tcpip.NetworkProtocolNumber
-
- // Revision is the version of the target.
- Revision uint8
-}
-
// A Target is the interface for taking an action for a packet.
type Target interface {
- // ID uniquely identifies the Target.
- ID() TargetID
-
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the index of the rule to jump to.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 6f73a0ce4..c9b13cd0e 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
return addr, nil, nil
@@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
@@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 33806340e..d2e37f38d 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -49,8 +49,8 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
- time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 4df288798..177bf5516 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -16,7 +16,6 @@ package stack
import (
"fmt"
- "time"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
@@ -68,7 +67,7 @@ var _ NUDHandler = (*neighborCache)(nil)
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
@@ -84,7 +83,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
- entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes)
if n.dynamic.count == neighborCacheSize {
e := n.dynamic.lru.Back()
e.mu.Lock()
@@ -111,28 +110,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// provided, it will be notified when address resolution is complete (success
// or not).
//
+// If specified, the local address must be an address local to the interface the
+// neighbor cache belongs to. The local address is the source address of a
+// packet prompting NUD/link address resolution.
+//
// If address resolution is required, ErrNoLinkAddress and a notification
// channel is returned for the top level caller to block. Channel is closed
// once address resolution is complete (success or not).
func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
+ Addr: remoteAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: 0,
}
return e, nil, nil
}
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
case Stale:
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
fallthrough
case Reachable, Static, Delay, Probe:
// As per RFC 4861 section 7.3.3:
@@ -152,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
@@ -207,7 +209,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
} else {
// Static entry found with the same address but different link address.
entry.neigh.LinkAddr = linkAddr
- entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
}
@@ -220,8 +222,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
- n.cache[addr] = entry
+ n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
// removeEntryLocked removes the specified entry from the neighbor cache.
@@ -292,8 +293,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
// by the caller.
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index fcd54ed83..ed33418f3 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -61,23 +61,20 @@ const (
)
// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
-// entries. The UpdatedAt field is ignored due to a lack of a deterministic
-// method to predict the time that an event will be dispatched.
+// entries. The UpdatedAtNanos field is ignored due to a lack of a
+// deterministic method to predict the time that an event will be dispatched.
func entryDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
// sort slices of entries for cases where ordering must be ignored.
func entryDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b NeighborEntry) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }))
}
func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
@@ -128,9 +125,8 @@ func newTestEntryStore() *testEntryStore {
linkAddr := toLinkAddress(i)
store.entriesMap[addr] = NeighborEntry{
- Addr: addr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: linkAddr,
+ Addr: addr,
+ LinkAddr: linkAddr,
}
}
return store
@@ -195,10 +191,10 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(addr)
+ r.fakeRequest(targetAddr)
})
// Execute post address resolution action, if available.
@@ -294,9 +290,8 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -305,15 +300,19 @@ func TestNeighborCacheEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -324,8 +323,8 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,9 +353,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -365,15 +364,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -391,9 +394,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -404,8 +409,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -452,8 +457,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -470,23 +475,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
})
}
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
}, testEntryEventInfo{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
})
c.nudDisp.mu.Lock()
@@ -508,10 +519,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -564,24 +574,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -600,9 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -640,9 +655,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -682,9 +699,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -703,9 +722,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -740,9 +761,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -760,9 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -800,24 +825,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -836,16 +864,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -861,10 +893,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
},
},
}
@@ -896,12 +927,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
clock.Advance(typicalLatency)
@@ -913,7 +944,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
id, ok := s.Fetch(false /* block */)
if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
}
if id != wakerID {
t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
@@ -923,15 +954,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -964,12 +999,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
// Remove the waker before the neighbor cache has the opportunity to send a
@@ -991,15 +1026,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1028,10 +1067,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
@@ -1041,9 +1079,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1058,10 +1098,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
},
},
}
@@ -1089,9 +1128,8 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1099,15 +1137,19 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1126,9 +1168,11 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1149,16 +1193,20 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1185,24 +1233,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1220,9 +1271,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1274,29 +1327,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1312,9 +1369,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
for i := neighborCacheSize; i < store.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
- _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1322,15 +1378,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1342,22 +1398,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1374,10 +1436,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
- Addr: frequentlyUsedEntry.Addr,
- LocalAddr: frequentlyUsedEntry.LocalAddr,
- LinkAddr: frequentlyUsedEntry.LinkAddr,
- State: Reachable,
+ Addr: frequentlyUsedEntry.Addr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
},
}
@@ -1387,10 +1448,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1430,9 +1490,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
@@ -1456,10 +1515,9 @@ func TestNeighborCacheConcurrent(t *testing.T) {
t.Errorf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1488,37 +1546,36 @@ func TestNeighborCacheReplace(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
}
if t.Failed() {
t.FailNow()
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1542,37 +1599,35 @@ func TestNeighborCacheReplace(t *testing.T) {
//
// Verify the entry's new link address and the new state.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
// Verify that the neighbor is now reachable.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1601,35 +1656,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
- got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
// Verify that address resolution for an unknown address returns ErrNoLinkAddress
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1659,13 +1713,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
}
@@ -1683,18 +1737,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
delay: typicalLatency,
}
- got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
- Addr: testEntryBroadcastAddr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: testEntryBroadcastLinkAddr,
- State: Static,
+ Addr: testEntryBroadcastAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1719,9 +1772,9 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh != nil {
<-doneCh
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index be61a21af..493e48031 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -24,13 +24,18 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+const (
+ // immediateDuration is a duration of zero for scheduling work that needs to
+ // be done immediately but asynchronously to avoid deadlock.
+ immediateDuration time.Duration = 0
+)
+
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
- Addr tcpip.Address
- LocalAddr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAtNanos int64
}
// NeighborState defines the state of a NeighborEntry within the Neighbor
@@ -106,35 +111,35 @@ type neighborEntry struct {
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
-func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
return &neighborEntry{
nic: nic,
linkRes: linkRes,
nudState: nudState,
neigh: NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- State: Unknown,
+ Addr: remoteAddr,
+ State: Unknown,
},
}
}
-// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
-// state. The entry can only transition out of Static by directly calling
-// `setStateLocked`.
+// newStaticNeighborEntry creates a neighbor cache entry starting at the
+// Static state. The entry can only transition out of Static by directly
+// calling `setStateLocked`.
func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ entry := NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAtNanos: nic.stack.clock.NowNanoseconds(),
+ }
if nic.stack.nudDisp != nil {
- nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, entry)
}
return &neighborEntry{
nic: nic,
nudState: state,
- neigh: NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
- },
+ neigh: entry,
}
}
@@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
-func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
}
}
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
-func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
}
}
@@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
// has been removed.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
@@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
prev := e.neigh.State
e.neigh.State = next
- e.neigh.UpdatedAt = time.Now()
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
config := e.nudState.Config()
switch next {
case Incomplete:
- var retryCounter uint32
- var sendMulticastProbe func()
-
- sendMulticastProbe = func() {
- if retryCounter == config.MaxMulticastProbes {
- // "If no Neighbor Advertisement is received after
- // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
- // The sender MUST return ICMP destination unreachable indications with
- // code 3 (Address Unreachable) for each packet queued awaiting address
- // resolution." - RFC 4861 section 7.2.2
- //
- // There is no need to send an ICMP destination unreachable indication
- // since the failure to resolve the address is expected to only occur
- // on this node. Thus, redirecting traffic is currently not supported.
- //
- // "If the error occurs on a node other than the node originating the
- // packet, an ICMP error message is generated. If the error occurs on
- // the originating node, an implementation is not required to actually
- // create and send an ICMP error packet to the source, as long as the
- // upper-layer sender is notified through an appropriate mechanism
- // (e.g. return value from a procedure call). Note, however, that an
- // implementation may find it convenient in some cases to return errors
- // to the sender by taking the offending packet, generating an ICMP
- // error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
- // There is no need to log the error here; the NUD implementation may
- // assume a working link. A valid link should be the responsibility of
- // the NIC/stack.LinkEndpoint.
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
-
- sendMulticastProbe()
+ panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
case Reachable:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(e.nudState.ReachableTime())
case Delay:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Probe)
e.setStateLocked(Probe)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(config.DelayFirstProbeTime)
@@ -277,24 +238,23 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
}
retryCounter++
- if retryCounter == config.MaxUnicastProbes {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(config.RetransmitTimer)
}
- sendUnicastProbe()
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(immediateDuration)
case Failed:
e.notifyWakersLocked()
@@ -315,15 +275,77 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
-func (e *neighborEntry) handlePacketQueuedLocked() {
+func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
case Unknown:
- e.dispatchAddEventLocked(Incomplete)
- e.setStateLocked(Incomplete)
+ e.neigh.State = Incomplete
+ e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds()
+
+ e.dispatchAddEventLocked()
+
+ config := e.nudState.Config()
+
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ // As per RFC 4861 section 7.2.2:
+ //
+ // If the source address of the packet prompting the solicitation is the
+ // same as one of the addresses assigned to the outgoing interface, that
+ // address SHOULD be placed in the IP Source Address of the outgoing
+ // solicitation.
+ //
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ // Send a probe in another gorountine to free this thread of execution
+ // for finishing the state transition. This is necessary to avoid
+ // deadlock where sending and processing probes are done synchronously,
+ // such as loopback and integration tests.
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(immediateDuration)
case Stale:
- e.dispatchChangeEventLocked(Delay)
e.setStateLocked(Delay)
+ e.dispatchChangeEventLocked()
case Incomplete, Reachable, Delay, Probe, Static, Failed:
// Do nothing
@@ -345,21 +367,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
switch e.neigh.State {
case Unknown, Incomplete, Failed:
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchAddEventLocked(Stale)
e.setStateLocked(Stale)
e.notifyWakersLocked()
+ e.dispatchAddEventLocked()
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Stale:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Static:
@@ -393,12 +415,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.neigh.LinkAddr = linkAddr
if flags.Solicited {
- e.dispatchChangeEventLocked(Reachable)
e.setStateLocked(Reachable)
} else {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
}
+ e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
e.notifyWakersLocked()
@@ -411,8 +432,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
break
}
@@ -421,23 +442,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if !flags.Solicited {
if e.neigh.State != Stale {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
} else {
// Notify the LinkAddr change, even though NUD state hasn't changed.
- e.dispatchChangeEventLocked(e.neigh.State)
+ e.dispatchChangeEventLocked()
}
break
}
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- }
+ wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyWakersLocked()
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
}
if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
@@ -475,11 +497,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- // Set state to Reachable again to refresh timers.
- }
+ wasReachable := e.neigh.State == Reachable
+ // Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
case Unknown, Incomplete, Failed, Static:
// Do nothing
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 3ee2a3b31..c2b763325 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -47,24 +47,27 @@ const (
entryTestNetDefaultMTU = 65536
)
+// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current
+// time.
+func runImmediatelyScheduledJobs(clock *faketime.ManualClock) {
+ clock.Advance(immediateDuration)
+}
+
// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
-// The UpdatedAt field is ignored due to a lack of a deterministic method to
-// predict the time that an event will be dispatched.
+// The UpdatedAtNanos field is ignored due to a lack of a deterministic method
+// to predict the time that an event will be dispatched.
func eventDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"),
}
}
// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
// sort slices of events for cases where ordering must be ignored.
func eventDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
+ }))
}
// The following unit tests exercise every state transition and verify its
@@ -125,14 +128,11 @@ func (t testEntryEventType) String() string {
type testEntryEventInfo struct {
EventType testEntryEventType
NICID tcpip.NICID
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Entry NeighborEntry
}
func (e testEntryEventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry)
}
// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
@@ -150,36 +150,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
d.events = append(d.events, e)
}
-func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestAdded,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestChanged,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestRemoved,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
@@ -202,9 +193,9 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
p := entryTestProbeInfo{
- RemoteAddress: addr,
+ RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
LocalAddress: localAddr,
}
@@ -245,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes)
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
@@ -323,15 +314,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
func TestEntryUnknownToIncomplete(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -350,9 +342,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
{
@@ -367,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
func TestEntryUnknownToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
@@ -377,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) {
e.mu.Unlock()
// No probes should have been sent.
+ runImmediatelyScheduledJobs(clock)
linkRes.mu.Lock()
diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
linkRes.mu.Unlock()
@@ -388,9 +383,11 @@ func TestEntryUnknownToStale(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -406,11 +403,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
- updatedAt := e.neigh.UpdatedAt
+ updatedAtNanos := e.neigh.UpdatedAtNanos
e.mu.Unlock()
clock.Advance(c.RetransmitTimer)
@@ -437,7 +434,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Lock()
- if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want {
t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
}
e.mu.Unlock()
@@ -468,16 +465,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -487,7 +488,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
nudDisp.mu.Unlock()
e.mu.Lock()
- if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant {
t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
}
e.mu.Unlock()
@@ -495,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
func TestEntryIncompleteToReachable(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -526,20 +520,35 @@ func TestEntryIncompleteToReachable(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -555,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
// to Reachable.
func TestEntryAddsAndClearsWakers(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
w := sleep.Waker{}
s := sleep.Sleeper{}
@@ -563,7 +572,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
defer s.Done()
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -587,34 +614,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -626,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: true,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.isRouter, true; got != want {
- t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -659,20 +666,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
}
linkRes.mu.Unlock()
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if !e.isRouter {
+ t.Errorf("got e.isRouter = %t, want = true", e.isRouter)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -684,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
func TestEntryIncompleteToStale(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Incomplete; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Incomplete {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete)
}
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -715,20 +733,35 @@ func TestEntryIncompleteToStale(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -744,7 +777,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -783,16 +816,20 @@ func TestEntryIncompleteToFailed(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -817,12 +854,30 @@ func (*testLocker) Unlock() {}
func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -848,34 +903,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -893,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -928,20 +959,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -961,17 +1014,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -986,29 +1032,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+
clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1026,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1058,27 +1110,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1086,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1132,27 +1184,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1160,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1206,27 +1262,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1234,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr1)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1279,20 +1340,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1304,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1343,27 +1408,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1375,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
e.mu.Lock()
- e.handlePacketQueuedLocked()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1400,41 +1511,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1446,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1485,27 +1570,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1517,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1552,27 +1651,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1584,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
func TestEntryStaleToDelay(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1616,27 +1728,48 @@ func TestEntryStaleToDelay(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1656,22 +1789,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleUpperLevelConfirmationLocked()
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1686,43 +1807,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1743,29 +1889,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: true,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Reachable; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1780,43 +1907,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.Advance(c.BaseReachableTime)
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Reachable {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr2 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2)
+ }
+ e.mu.Unlock()
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1837,13 +1996,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if e.neigh.State != Delay {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
@@ -1860,57 +2037,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
}
e.mu.Unlock()
- wantProbes := []entryTestProbeInfo{
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
- }
-
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1922,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
- t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -1962,27 +2115,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ if e.neigh.LinkAddr != entryTestLinkAddr1 {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1994,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleProbeLocked(entryTestLinkAddr2)
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2027,34 +2197,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2066,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
c := DefaultNUDConfigurations()
- e, nudDisp, linkRes, _ := entryTestSetup(c)
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
- e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: false,
- IsRouter: false,
- })
- e.handlePacketQueuedLocked()
- if got, want := e.neigh.State, Delay; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
- Solicited: false,
- Override: true,
- IsRouter: false,
- })
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
+ runImmediatelyScheduledJobs(clock)
wantProbes := []entryTestProbeInfo{
{
RemoteAddress: entryTestAddr1,
@@ -2103,34 +2281,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ if e.neigh.State != Delay {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if e.neigh.State != Stale {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale)
+ }
+ e.mu.Unlock()
+
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2145,69 +2351,91 @@ func TestEntryDelayToProbe(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2228,36 +2456,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2274,37 +2516,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2312,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
@@ -2325,36 +2571,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2375,37 +2635,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2413,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Stale; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
@@ -2426,36 +2690,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2479,30 +2758,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2529,17 +2816,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
wantProbes := []entryTestProbeInfo{
- // Probe caused by the Delay-to-Probe transition
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2567,42 +2851,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2622,36 +2915,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2672,49 +2979,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2734,36 +3052,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2781,49 +3113,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2843,36 +3186,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The second probe is caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
e.mu.Lock()
@@ -2890,49 +3247,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e.mu.Unlock()
clock.Advance(c.BaseReachableTime)
-
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2946,87 +3314,116 @@ func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
c.MaxUnicastProbes = 3
+ c.DelayFirstProbeTime = c.RetransmitTimer
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.Advance(waitFor)
+ // Observe each probe sent while in the Probe state.
+ for i := uint32(0); i < c.MaxUnicastProbes; i++ {
+ clock.Advance(c.RetransmitTimer)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ }
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.mu.Unlock()
}
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+
+ // Wait for the last probe to expire, causing a transition to Failed.
+ clock.Advance(c.RetransmitTimer)
+ e.mu.Lock()
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -3034,12 +3431,6 @@ func TestEntryProbeToFailed(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryFailedGetsDeleted(t *testing.T) {
@@ -3054,84 +3445,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
}
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ runImmediatelyScheduledJobs(clock)
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
clock.Advance(waitFor)
-
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- }
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ {
+ wantProbes := []entryTestProbeInfo{
+ // The next three probe are sent in Probe.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index dcd4319bf..60c81a3aa 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
return n.writePacket(r, gso, protocol, pkt)
}
+// WritePacketToRemote implements NetworkInterface.
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ r := Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return n.writePacket(&r, gso, protocol, pkt)
+}
+
func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
@@ -339,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address
return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
+func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
+ ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint)
+ if ep != nil {
+ ep.DecRef()
+ return true
+ }
+
+ return false
+}
+
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint {
return n.getAddressOrCreateTemp(protocol, address, peb, spoofing)
@@ -546,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
}
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
- r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
defer r.Release()
- r.RemoteLinkAddress = remotelinkAddr
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ r.PopulatePacketInfo(pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -585,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
if local == "" {
local = n.LinkEndpoint.LinkAddress()
}
+ pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0
// Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
@@ -660,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
}
// Found a NIC.
- n := r.nic
+ n := r.localAddressNIC
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
- r.RemoteLinkAddress = remote
+ pkt.NICID = n.ID()
r.RemoteAddress = src
- // TODO(b/123449044): Update the source NIC as well.
- n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+ n.getNetworkEndpoint(protocol).HandlePacket(pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -678,7 +697,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
+ // the TTL field for ipv4/ipv6.
// pkt may have set its header and may not have enough headroom for
// link-layer header for the other link to prepend. Here we create a new
@@ -725,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
+func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -737,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, pkt)
+ n.stack.demux.deliverRawPacket(protocol, pkt)
// TransportHeader is empty only when pkt is an ICMP packet or was reassembled
// from fragments.
@@ -766,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
return TransportPacketHandled
}
- id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
+ netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber]
+ if !ok {
+ panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers()))
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+ id := TransportEndpointID{
+ LocalPort: dstPort,
+ LocalAddress: dst,
+ RemotePort: srcPort,
+ RemoteAddress: src,
+ }
+ if n.stack.demux.deliverPacket(protocol, pkt, id) {
return TransportPacketHandled
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, pkt) {
+ if state.defaultHandler(id, pkt) {
return TransportPacketHandled
}
}
@@ -781,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet so
// give the protocol specific error handler a chance to handle it.
// If it doesn't handle it then we should do so.
- switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res {
+ switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res {
case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
return TransportPacketHandled
@@ -885,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 97a96af62..5b5c58afb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
}
// HandlePacket implements NetworkEndpoint.HandlePacket.
-func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
-}
+func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {}
// Close implements NetworkEndpoint.Close.
func (e *testIPv6Endpoint) Close() {
@@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index e1ec15487..ab629b3a4 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -129,7 +129,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborAdded(tcpip.NICID, NeighborEntry)
// OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
// neighbor table changes state and/or link address.
@@ -138,7 +138,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborChanged(tcpip.NICID, NeighborEntry)
// OnNeighborRemoved will be called when an entry is removed from a NIC's
// (with ID nicID) neighbor table.
@@ -147,7 +147,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborRemoved(tcpip.NICID, NeighborEntry)
}
// ReachabilityConfirmationFlags describes the flags used within a reachability
@@ -177,7 +177,7 @@ type NUDHandler interface {
// Neighbor Solicitation for ARP or NDP, respectively). Validation of the
// probe needs to be performed before calling this function since the
// Neighbor Cache doesn't have access to view the NIC's assigned addresses.
- HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+ HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 7f54a6de8..664cc6fa0 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -112,6 +112,16 @@ type PacketBuffer struct {
// PktType indicates the SockAddrLink.PacketType of the packet as defined in
// https://www.man7.org/linux/man-pages/man7/packet.7.html.
PktType tcpip.PacketType
+
+ // NICID is the ID of the interface the network packet was received at.
+ NICID tcpip.NICID
+
+ // RXTransportChecksumValidated indicates that transport checksum verification
+ // may be safely skipped.
+ RXTransportChecksumValidated bool
+
+ // NetworkPacketInfo holds an incoming packet's network-layer information.
+ NetworkPacketInfo NetworkPacketInfo
}
// NewPacketBuffer creates a new PacketBuffer with opts.
@@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// Clone should be called in such cases so that no modifications is done to
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
- newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
- TransportProtocolNumber: pk.TransportProtocolNumber,
+ return &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
+ PktType: pk.PktType,
+ NICID: pk.NICID,
+ RXTransportChecksumValidated: pk.RXTransportChecksumValidated,
+ NetworkPacketInfo: pk.NetworkPacketInfo,
}
- return newPk
+}
+
+// SourceLinkAddress returns the source link address of the packet.
+func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress {
+ link := pk.LinkHeader().View()
+
+ if link.IsEmpty() {
+ return ""
+ }
+
+ return header.Ethernet(link).SourceAddress()
}
// Network returns the network header as a header.Network.
@@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network {
}
}
+// CloneToInbound makes a shallow copy of the packet buffer to be used as an
+// inbound packet.
+//
+// See PacketBuffer.Data for details about how a packet buffer holds an inbound
+// packet.
+func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
+ return NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
+ })
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index f838eda8d..5d364a2b0 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
} else if _, err := p.route.Resolve(nil); err != nil {
p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index defb9129b..b8f333057 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -63,17 +63,28 @@ const (
ControlUnknown
)
+// NetworkPacketInfo holds information about a network layer packet.
+type NetworkPacketInfo struct {
+ // RemoteAddressBroadcast is true if the packet's remote address is a
+ // broadcast address.
+ RemoteAddressBroadcast bool
+
+ // LocalAddressBroadcast is true if the packet's local address is a broadcast
+ // address.
+ LocalAddressBroadcast bool
+}
+
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// UniqueID returns an unique ID for this transport endpoint.
UniqueID() uint64
- // HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint. It sets pkt.TransportHeader.
+ // HandlePacket is called by the stack when new packets arrive to this
+ // transport endpoint. It sets the packet buffer's transport header.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(TransportEndpointID, *PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
@@ -105,8 +116,8 @@ type RawTransportEndpoint interface {
// this transport endpoint. The packet contains all data from the link
// layer up.
//
- // HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ // HandlePacket takes ownership of the packet.
+ HandlePacket(*PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -172,9 +183,9 @@ type TransportProtocol interface {
// protocol that don't match any existing endpoint. For example,
// it is targeted at a port that has no listeners.
//
- // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // HandleUnknownDestinationPacket takes ownership of the packet if it handles
// the issue.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
+ HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -227,8 +238,8 @@ type TransportDispatcher interface {
//
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
- // DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition
+ // DeliverTransportPacket takes ownership of the packet.
+ DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface {
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
+ // Subnet returns the subnet of the endpoint's address.
+ Subnet() tcpip.Subnet
+
// IsAssigned returns whether or not the endpoint is considered bound
// to its NetworkEndpoint.
IsAssigned(allowExpired bool) bool
@@ -490,6 +504,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // WritePacketToRemote writes the packet to the given remote link address.
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -544,7 +561,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt *PacketBuffer)
+ HandlePacket(pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -764,13 +781,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
- // the request on the local network if remoteLinkAddr is the zero value. The
- // request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the link address of the target
+ // address. The request is broadcasted on the local network if a remote link
+ // address is not provided.
//
- // A valid response will cause the discovery protocol's network
- // endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
+ // The request is sent from the passed network interface. If the interface
+ // local address is unspecified, any interface local address may be used.
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..15ff437c7 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -15,6 +15,8 @@
package stack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -45,11 +47,16 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
- // nic is the NIC the route goes through.
- nic *NIC
+ // localAddressNIC is the interface the address is associated with.
+ // TODO(gvisor.dev/issue/4548): Remove this field once we can query the
+ // address's assigned status without the NIC.
+ localAddressNIC *NIC
+
+ // localAddressEndpoint is the local address this route is associated with.
+ localAddressEndpoint AssignableAddressEndpoint
- // addressEndpoint is the local address this route is associated with.
- addressEndpoint AssignableAddressEndpoint
+ // outgoingNIC is the interface this route uses to write packets.
+ outgoingNIC *NIC
// linkCache is set if link address resolution is enabled for this protocol on
// the route's NIC.
@@ -60,51 +67,144 @@ type Route struct {
linkRes LinkAddressResolver
}
+// constructAndValidateRoute validates and initializes a route. It takes
+// ownership of the provided local address.
+//
+// Returns an empty route if validation fails.
+func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route {
+ addrWithPrefix := addressEndpoint.AddressWithPrefix()
+
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) {
+ addressEndpoint.DecRef()
+ return Route{}
+ }
+
+ // If no remote address is provided, use the local address.
+ if len(remoteAddr) == 0 {
+ remoteAddr = addrWithPrefix.Address
+ }
+
+ r := makeRoute(
+ netProto,
+ addrWithPrefix.Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ addressEndpoint,
+ handleLocal,
+ multicastLoop,
+ )
+
+ // If the route requires us to send a packet through some gateway, do not
+ // broadcast it.
+ if len(gateway) > 0 {
+ r.NextHop = gateway
+ } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ }
+
+ return r
+}
+
// makeRoute initializes a new route. It takes ownership of the provided
// AssignableAddressEndpoint.
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route {
+ if localAddressNIC.stack != outgoingNIC.stack {
+ panic(fmt.Sprintf("cannot create a route with NICs from different stacks"))
+ }
+
loop := PacketOut
- if handleLocal && localAddr != "" && remoteAddr == localAddr {
- loop = PacketLoop
- } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
- loop |= PacketLoop
- } else if remoteAddr == header.IPv4Broadcast {
- loop |= PacketLoop
+
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if !outgoingNIC.IsLoopback() {
+ if handleLocal && localAddr != "" && remoteAddr == localAddr {
+ loop = PacketLoop
+ } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) {
+ loop |= PacketLoop
+ } else if remoteAddr == header.IPv4Broadcast {
+ loop |= PacketLoop
+ } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
+ loop |= PacketLoop
+ }
}
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route {
r := Route{
- NetProto: netProto,
- LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
- RemoteAddress: remoteAddr,
- addressEndpoint: addressEndpoint,
- nic: nic,
- Loop: loop,
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(),
+ RemoteAddress: remoteAddr,
+ localAddressNIC: localAddressNIC,
+ localAddressEndpoint: localAddressEndpoint,
+ outgoingNIC: outgoingNIC,
+ Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = r.outgoingNIC.stack
}
}
return r
}
+// makeLocalRoute initializes a new local route. It takes ownership of the
+// provided AssignableAddressEndpoint.
+//
+// A local route is a route to a destination that is local to the stack.
+func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route {
+ loop := PacketLoop
+ // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the
+ // link endpoint level. We can remove this check once loopback interfaces
+ // loop back packets at the network layer.
+ if outgoingNIC.IsLoopback() {
+ loop = PacketOut
+ }
+ return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop)
+}
+
+// PopulatePacketInfo populates a packet buffer's packet information fields.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) {
+ if r.local() {
+ pkt.RXTransportChecksumValidated = true
+ }
+ pkt.NetworkPacketInfo = r.networkPacketInfo()
+}
+
+// networkPacketInfo returns the network packet information of the route.
+//
+// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by
+// the network layer.
+func (r *Route) networkPacketInfo() NetworkPacketInfo {
+ return NetworkPacketInfo{
+ RemoteAddressBroadcast: r.IsOutboundBroadcast(),
+ LocalAddressBroadcast: r.isInboundBroadcast(),
+ }
+}
+
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.nic.ID()
+ return r.outgoingNIC.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
- return r.nic.stack.Stats()
+ return r.outgoingNIC.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
@@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen)
}
-// Capabilities returns the link-layer capabilities of the route.
-func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+// RequiresTXTransportChecksum returns false if the route does not require
+// transport checksums to be populated.
+func (r *Route) RequiresTXTransportChecksum() bool {
+ if r.local() {
+ return false
+ }
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0
+}
+
+// HasSoftwareGSOCapability returns true if the route supports software GSO.
+func (r *Route) HasSoftwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0
+}
+
+// HasHardwareGSOCapability returns true if the route supports hardware GSO.
+func (r *Route) HasHardwareGSOCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0
+}
+
+// HasSaveRestoreCapability returns true if the route supports save/restore.
+func (r *Route) HasSaveRestoreCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0
+}
+
+// HasDisconncetOkCapability returns true if the route supports disconnecting.
+func (r *Route) HasDisconncetOkCapability() bool {
+ return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
+ if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
@@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
- entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker)
+ // If specified, the local address used for link address resolution must be an
+ // address on the outgoing interface.
+ var linkAddressResolutionRequestLocalAddr tcpip.Address
+ if r.localAddressNIC == r.outgoingNIC {
+ linkAddressResolutionRequestLocalAddr = r.LocalAddress
+ }
+
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
+ entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker)
if err != nil {
return ch, err
}
@@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
return nil, nil
}
- linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker)
if err != nil {
return ch, err
}
@@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr = r.RemoteAddress
}
- if neigh := r.nic.neigh; neigh != nil {
+ if neigh := r.outgoingNIC.neigh; neigh != nil {
neigh.removeWaker(nextAddr, waker)
return
}
- r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker)
+ r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker)
+}
+
+// local returns true if the route is a local route.
+func (r *Route) local() bool {
+ return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback()
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before the route can be written to.
//
-// The NIC r uses must not be locked.
+// The NICs the route is associated with must not be locked.
func (r *Route) IsResolutionRequired() bool {
- if r.nic.neigh != nil {
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == ""
+ if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() {
+ return false
}
- return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == ""
+
+ return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil
+}
+
+func (r *Route) isValidForOutgoing() bool {
+ if !r.outgoingNIC.Enabled() {
+ return false
+ }
+
+ if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) {
+ return false
+ }
+
+ // If the source NIC and outgoing NIC are different, make sure the stack has
+ // forwarding enabled, or the packet will be handled locally.
+ if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) {
+ return false
+ }
+
+ return true
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
// the number of packets written.
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
- if !r.nic.isValidForOutgoing(r.addressEndpoint) {
+ if !r.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.nic.getNetworkEndpoint(r.NetProto).MTU()
+ return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
- if r.addressEndpoint != nil {
- r.addressEndpoint.DecRef()
- r.addressEndpoint = nil
+ if r.localAddressEndpoint != nil {
+ r.localAddressEndpoint.DecRef()
+ r.localAddressEndpoint = nil
}
}
// Clone clones the route.
func (r *Route) Clone() Route {
- if r.addressEndpoint != nil {
- _ = r.addressEndpoint.IncRef()
+ if r.localAddressEndpoint != nil {
+ if !r.localAddressEndpoint.IncRef() {
+ panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress))
+ }
}
return *r
}
@@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route {
// Stack returns the instance of the Stack that owns this route.
func (r *Route) Stack() *Stack {
- return r.nic.stack
+ return r.outgoingNIC.stack
}
func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
@@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
return true
}
- subnet := r.addressEndpoint.AddressWithPrefix().Subnet()
+ subnet := r.localAddressEndpoint.Subnet()
return subnet.IsBroadcast(addr)
}
@@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool {
return r.isV4Broadcast(r.RemoteAddress)
}
-// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// isInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
-func (r *Route) IsInboundBroadcast() bool {
+func (r *Route) isInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
return r.isV4Broadcast(r.LocalAddress)
}
@@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool {
// ReverseRoute returns new route with given source and destination address.
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
return Route{
- NetProto: r.NetProto,
- LocalAddress: dst,
- LocalLinkAddress: r.RemoteLinkAddress,
- RemoteAddress: src,
- RemoteLinkAddress: r.LocalLinkAddress,
- Loop: r.Loop,
- addressEndpoint: r.addressEndpoint,
- nic: r.nic,
- linkCache: r.linkCache,
- linkRes: r.linkRes,
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ Loop: r.Loop,
+ localAddressNIC: r.localAddressNIC,
+ localAddressEndpoint: r.localAddressEndpoint,
+ outgoingNIC: r.outgoingNIC,
+ linkCache: r.linkCache,
+ linkRes: r.linkRes,
}
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3a07577c8..a23fb97ff 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"bytes"
"encoding/binary"
+ "fmt"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -52,7 +53,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -518,6 +519,10 @@ type Options struct {
//
// RandSource must be thread-safe.
RandSource mathrand.Source
+
+ // IPTables are the initial iptables rules. If nil, iptables will allow
+ // all traffic.
+ IPTables *IPTables
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -620,6 +625,10 @@ func New(opts Options) *Stack {
randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
}
+ if opts.IPTables == nil {
+ opts.IPTables = DefaultTables()
+ }
+
opts.NUDConfigs.resetInvalidFields()
s := &Stack{
@@ -633,7 +642,7 @@ func New(opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- tables: DefaultTables(),
+ tables: opts.IPTables,
icmpRateLimiter: NewICMPRateLimiter(),
seed: generateRandUint32(),
nudConfigs: opts.NUDConfigs,
@@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -830,6 +839,20 @@ func (s *Stack) AddRoute(route tcpip.Route) {
s.routeTable = append(s.routeTable, route)
}
+// RemoveRoutes removes matching routes from the route table.
+func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var filteredRoutes []tcpip.Route
+ for _, route := range s.routeTable {
+ if !match(route) {
+ filteredRoutes = append(filteredRoutes, route)
+ }
+ }
+ s.routeTable = filteredRoutes
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -1180,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
+// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route
+// from the specified NIC.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint)
+ if localAddressEndpoint == nil {
+ return Route{}, false
+ }
+
+ var outgoingNIC *NIC
+ // Prefer a local route to the same interface as the local address.
+ if localAddressNIC.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = localAddressNIC
+ }
+
+ // If the remote address isn't owned by the local address's NIC, check all
+ // NICs.
+ if outgoingNIC == nil {
+ for _, nic := range s.nics {
+ if nic.hasAddress(netProto, remoteAddr) {
+ outgoingNIC = nic
+ break
+ }
+ }
+ }
+
+ // If the remote address is not owned by the stack, we can't return a local
+ // route.
+ if outgoingNIC == nil {
+ localAddressEndpoint.DecRef()
+ return Route{}, false
+ }
+
+ r := makeLocalRoute(
+ netProto,
+ localAddressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ outgoingNIC,
+ localAddressNIC,
+ localAddressEndpoint,
+ )
+
+ if r.IsOutboundBroadcast() {
+ r.Release()
+ return Route{}, false
+ }
+
+ return r, true
+}
+
+// findLocalRouteRLocked returns a local route.
+//
+// A local route is a route to some remote address which the stack owns. That
+// is, a local route is a route where packets never have to leave the stack.
+//
+// Precondition: s.mu must be read locked.
+func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) {
+ if len(localAddr) == 0 {
+ localAddr = remoteAddr
+ }
+
+ if localAddressNICID == 0 {
+ for _, localAddressNIC := range s.nics {
+ if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok {
+ return r, true
+ }
+ }
+
+ return Route{}, false
+ }
+
+ if localAddressNIC, ok := s.nics[localAddressNICID]; ok {
+ return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto)
+ }
+
+ return Route{}, false
+}
+
// FindRoute creates a route to the given destination address, leaving through
-// the given nic and local address (if provided).
+// the given NIC and local address (if provided).
+//
+// If a NIC is not specified, the returned route will leave through the same
+// NIC as the NIC that has the local address assigned when forwarding is
+// disabled. If forwarding is enabled and the NIC is unspecified, the route may
+// leave through any interface unless the route is link-local.
+//
+// If no local address is provided, the stack will select a local address. If no
+// remote address is provided, the stack wil use a remote address equal to the
+// local address.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
+ isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
+ needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback)
+
+ if s.handleLocal && !isMulticast && !isLocalBroadcast {
+ if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok {
+ return r, nil
+ }
+ }
+
+ // If the interface is specified and we do not need a route, return a route
+ // through the interface if the interface is valid and enabled.
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.Enabled() {
if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil
+ return makeRoute(
+ netProto,
+ addressEndpoint.AddressWithPrefix().Address,
+ remoteAddr,
+ nic, /* outboundNIC */
+ nic, /* localAddressNIC*/
+ addressEndpoint,
+ s.handleLocal,
+ multicastLoop,
+ ), nil
}
}
- } else {
- for _, route := range s.routeTable {
- if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
- continue
+
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
+ }
+
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+
+ // Find a route to the remote with the route table.
+ var chosenRoute tcpip.Route
+ for _, route := range s.routeTable {
+ if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) {
+ continue
+ }
+
+ nic, ok := s.nics[route.NIC]
+ if !ok || !nic.Enabled() {
+ continue
+ }
+
+ if id == 0 || id == route.NIC {
+ if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = route.Gateway
+ }
+ r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop)
+ if r == (Route{}) {
+ panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr))
+ }
+ return r, nil
}
- if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() {
- if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil {
- if len(remoteAddr) == 0 {
- // If no remote address was provided, then the route
- // provided will refer to the link local address.
- remoteAddr = addressEndpoint.AddressWithPrefix().Address
- }
+ }
- r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback())
- if len(route.Gateway) > 0 {
- if needRoute {
- r.NextHop = route.Gateway
- }
- } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+ // If the stack has forwarding enabled and we haven't found a valid route to
+ // the remote address yet, keep track of the first valid route. We keep
+ // iterating because we prefer routes that let us use a local address that
+ // is assigned to the outgoing interface. There is no requirement to do this
+ // from any RFC but simply a choice made to better follow a strong host
+ // model which the netstack follows at the time of writing.
+ if canForward && chosenRoute == (tcpip.Route{}) {
+ chosenRoute = route
+ }
+ }
+
+ if chosenRoute != (tcpip.Route{}) {
+ // At this point we know the stack has forwarding enabled since chosenRoute is
+ // only set when forwarding is enabled.
+ nic, ok := s.nics[chosenRoute.NIC]
+ if !ok {
+ // If the route's NIC was invalid, we should not have chosen the route.
+ panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC))
+ }
+
+ var gateway tcpip.Address
+ if needRoute {
+ gateway = chosenRoute.Gateway
+ }
+
+ // Use the specified NIC to get the local address endpoint.
+ if id != 0 {
+ if aNIC, ok := s.nics[id]; ok {
+ if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil {
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
+ return r, nil
}
+ }
+ }
+
+ return Route{}, tcpip.ErrNoRoute
+ }
+ if id == 0 {
+ // If an interface is not specified, try to find a NIC that holds the local
+ // address endpoint to construct a route.
+ for _, aNIC := range s.nics {
+ addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto)
+ if addressEndpoint == nil {
+ continue
+ }
+
+ if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) {
return r, nil
}
}
}
}
- if !needRoute {
- return Route{}, tcpip.ErrNetworkUnreachable
+ if needRoute {
+ return Route{}, tcpip.ErrNoRoute
}
-
- return Route{}, tcpip.ErrNoRoute
+ if isLoopback {
+ return Route{}, tcpip.ErrBadLocalAddress
+ }
+ return Route{}, tcpip.ErrNetworkUnreachable
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
@@ -1323,7 +1517,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1443,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
// FindTransportEndpoint finds an endpoint that most closely matches the provided
// id. If no endpoint is found it returns nil.
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
- return s.demux.findTransportEndpoint(netProto, transProto, id, r)
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, nicID)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1896,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job {
return tcpip.NewJob(s.clock, l, f)
}
+
+// ParseResult indicates the result of a parsing attempt.
+type ParseResult int
+
+const (
+ // ParsedOK indicates that a packet was successfully parsed.
+ ParsedOK ParseResult = iota
+
+ // UnknownNetworkProtocol indicates that the network protocol is unknown.
+ UnknownNetworkProtocol
+
+ // NetworkLayerParseError indicates that the network packet was not
+ // successfully parsed.
+ NetworkLayerParseError
+
+ // UnknownTransportProtocol indicates that the transport protocol is unknown.
+ UnknownTransportProtocol
+
+ // TransportLayerParseError indicates that the transport packet was not
+ // successfully parsed.
+ TransportLayerParseError
+)
+
+// ParsePacketBuffer parses the provided packet buffer.
+func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult {
+ netProto, ok := s.networkProtocols[protocol]
+ if !ok {
+ return UnknownNetworkProtocol
+ }
+
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ return NetworkLayerParseError
+ }
+ if !hasTransportHdr {
+ return ParsedOK
+ }
+
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber {
+ return ParsedOK
+ }
+
+ pkt.TransportProtocolNumber = transProtoNum
+ // Parse the transport header if present.
+ state, ok := s.transportProtocols[transProtoNum]
+ if !ok {
+ return UnknownTransportProtocol
+ }
+
+ if !state.proto.Parse(pkt) {
+ return TransportLayerParseError
+ }
+
+ return ParsedOK
+}
+
+// networkProtocolNumbers returns the network protocol numbers the stack is
+// configured with.
+func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber {
+ protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols))
+ for p := range s.networkProtocols {
+ protos = append(protos, p)
+ }
+ return protos
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e75f58c64..dedfdd435 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
@@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
+ netHdr := pkt.NetworkHeader().View()
+ f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++
// Handle control packets.
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
@@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+ f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ pkt.NetworkProtocolNumber = fakeNetNumber
hdr[dstAddrOffset] = r.RemoteAddress[0]
hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- f.HandlePacket(r, pkt)
+ pkt := pkt.Clone()
+ r.PopulatePacketInfo(pkt)
+ f.HandlePacket(pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto
if !ok {
return 0, false, false
}
+ pkt.NetworkProtocolNumber = fakeNetNumber
return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
}
@@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) {
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
+// TestExternalSendWithHandleLocal tests that the stack creates a non-local
+// route when spoofing or promiscuous mode are enabled.
+//
+// This test makes sure that packets are transmitted from the stack.
+func TestExternalSendWithHandleLocal(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+
+ localAddr = tcpip.Address("\x01")
+ dstAddr = tcpip.Address("\x03")
+ )
+
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ name string
+ configureStack func(*testing.T, *stack.Stack)
+ }{
+ {
+ name: "Default",
+ configureStack: func(*testing.T, *stack.Stack) {},
+ },
+ {
+ name: "Spoofing",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetSpoofing(nicID, true); err != nil {
+ t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ {
+ name: "Promiscuous",
+ configureStack: func(t *testing.T, s *stack.Stack) {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
+ t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, handleLocal := range []bool{true, false} {
+ t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ HandleLocal: handleLocal,
+ })
+
+ ep := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}})
+
+ test.configureStack(t, s)
+
+ r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err)
+ }
+ defer r.Release()
+
+ if r.LocalAddress != localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr)
+ }
+
+ if n := ep.Drain(); n != 0 {
+ t.Fatalf("got ep.Drain() = %d, want = 0", n)
+ }
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: fakeTransNumber,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.NewView(10).ToVectorisedView(),
+ })); err != nil {
+ t.Fatalf("r.WritePacket(nil, _, _): %s", err)
+ }
+ if n := ep.Drain(); n != 1 {
+ t.Fatalf("got ep.Drain() = %d, want = 1", n)
+ }
+ })
+ }
+ })
+ }
+}
+
func TestSpoofingWithAddress(t *testing.T) {
localAddr := tcpip.Address("\x01")
nonExistentLocalAddr := tcpip.Address("\x02")
@@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
RemoteAddress: ipv4SubnetBcast,
RemoteLinkAddress: header.EthernetBroadcastAddress,
NetProto: header.IPv4ProtocolNumber,
- Loop: stack.PacketOut,
+ Loop: stack.PacketOut | stack.PacketLoop,
},
},
// Broadcast to a locally attached /31 subnet does not populate the
@@ -3672,3 +3778,453 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
}
}
+
+// TestAddRoute tests Stack.AddRoute
+func TestAddRoute(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ subnet1, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expected := []tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ }
+
+ // Initialize the route table with one route.
+ s.SetRouteTable([]tcpip.Route{expected[0]})
+
+ // Add another route.
+ s.AddRoute(expected[1])
+
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+// TestRemoveRoutes tests Stack.RemoveRoutes
+func TestRemoveRoutes(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ addressToRemove := tcpip.Address("\x01")
+ subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet3, err := tcpip.NewSubnet("\x02", "\x02")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Initialize the route table with three routes.
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ {Destination: subnet3, Gateway: "\x00", NIC: 1},
+ })
+
+ // Remove routes with the specific address.
+ s.RemoveRoutes(func(r tcpip.Route) bool {
+ return r.Destination.ID() == addressToRemove
+ })
+
+ expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}}
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+func TestFindRouteWithForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ nic1Addr = tcpip.Address("\x01")
+ nic2Addr = tcpip.Address("\x02")
+ remoteAddr = tcpip.Address("\x03")
+ )
+
+ type netCfg struct {
+ proto tcpip.NetworkProtocolNumber
+ factory stack.NetworkProtocolFactory
+ nic1Addr tcpip.Address
+ nic2Addr tcpip.Address
+ remoteAddr tcpip.Address
+ }
+
+ fakeNetCfg := netCfg{
+ proto: fakeNetNumber,
+ factory: fakeNetFactory,
+ nic1Addr: nic1Addr,
+ nic2Addr: nic2Addr,
+ remoteAddr: remoteAddr,
+ }
+
+ globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16())
+ globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16())
+
+ ipv6LinkLocalNIC1WithGlobalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: llAddr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: globalIPv6Addr1,
+ }
+ ipv6GlobalNIC1WithLinkLocalRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: llAddr1,
+ remoteAddr: llAddr2,
+ }
+ ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{
+ proto: ipv6.ProtocolNumber,
+ factory: ipv6.NewProtocol,
+ nic1Addr: globalIPv6Addr1,
+ nic2Addr: globalIPv6Addr2,
+ remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ }
+
+ tests := []struct {
+ name string
+
+ netCfg netCfg
+ forwardingEnabled bool
+
+ addrNIC tcpip.NICID
+ localAddr tcpip.Address
+
+ findRouteErr *tcpip.Error
+ dependentOnForwarding bool
+ }{
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC but route from different NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on specified NIC and route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr not on specified NIC but route from same NIC",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ addrNIC: nicID2,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on same NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: false,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and localAddr on different NIC as route",
+ netCfg: fakeNetCfg,
+ forwardingEnabled: true,
+ localAddr: fakeNetCfg.nic1Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: true,
+ },
+ {
+ name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ addrNIC: nicID1,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on different NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNoRoute,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with route on same NIC",
+ netCfg: ipv6LinkLocalNIC1WithGlobalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and link-local local addr with route on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on different NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr,
+ findRouteErr: tcpip.ErrNetworkUnreachable,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding disabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: false,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ {
+ name: "forwarding enabled and global local addr with link-local multicast remote on same NIC",
+ netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote,
+ forwardingEnabled: true,
+ localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr,
+ findRouteErr: nil,
+ dependentOnForwarding: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err)
+ }
+
+ if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err)
+ }
+
+ if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil {
+ t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}})
+
+ r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */)
+ if err != test.findRouteErr {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr)
+ }
+ defer r.Release()
+
+ if test.findRouteErr != nil {
+ return
+ }
+
+ if r.LocalAddress != test.localAddr {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr)
+ }
+ if r.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ // Sending a packet should always go through NIC2 since we only install a
+ // route to test.netCfg.remoteAddr through NIC2.
+ data := buffer.View([]byte{1, 2, 3, 4})
+ if err := send(r, data); err != nil {
+ t.Fatalf("send(_, _): %s", err)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not sent through ep2")
+ }
+ if pkt.Route.LocalAddress != test.localAddr {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr)
+ }
+ if pkt.Route.RemoteAddress != test.netCfg.remoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr)
+ }
+
+ if !test.forwardingEnabled || !test.dependentOnForwarding {
+ return
+ }
+
+ // Disabling forwarding when the route is dependent on forwarding being
+ // enabled should make the route invalid.
+ if err := s.SetForwarding(test.netCfg.proto, false); err != nil {
+ t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err)
+ }
+ if err := send(r, data); err != tcpip.ErrInvalidEndpointState {
+ t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState)
+ }
+ if n := ep1.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep1", n)
+ }
+ if n := ep2.Drain(); n != 0 {
+ t.Errorf("got %d unexpected packets from ep2", n)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 35e5b1a2e..f183ec6e4 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) {
epsByNIC.mu.RLock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[pkt.NICID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isInboundMulticastOrBroadcast(r) {
- mpep.handlePacketAll(r, id, pkt)
+ if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
+ mpep.handlePacketAll(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
// multiPortEndpoints are guaranteed to have at least one element.
transEP := selectEndpoint(id, mpep, epsByNIC.seed)
if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
- queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ queuedProtocol.QueuePacket(transEP, id, pkt)
epsByNIC.mu.RUnlock()
return
}
- transEP.HandlePacket(r, id, pkt)
+ transEP.HandlePacket(id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// based on endpoints IDs. It should only be instantiated via
// newTransportDemuxer.
type transportDemuxer struct {
+ stack *Stack
+
// protocol is immutable.
protocol map[protocolIDs]*transportEndpoints
queuedProtocols map[protocolIDs]queuedTransportProtocol
@@ -262,11 +264,12 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
+ QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{
+ stack: stack,
protocol: make(map[protocolIDs]*transportEndpoints),
queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
}
@@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
if mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
} else {
- endpoint.HandlePacket(r, id, pkt.Clone())
+ endpoint.HandlePacket(id, pkt.Clone())
}
}
if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
- queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ queuedProtocol.QueuePacket(endpoint, id, pkt)
} else {
- endpoint.HandlePacket(r, id, pkt)
+ endpoint.HandlePacket(id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
if len(destEPs) == 0 {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
return false
}
// handlePacket takes ownership of pkt, so each endpoint needs its own
// copy except for the final one.
for _, ep := range destEPs[:len(destEPs)-1] {
- ep.handlePacket(r, id, pkt.Clone())
+ ep.handlePacket(id, pkt.Clone())
}
- destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ destEPs[len(destEPs)-1].handlePacket(id, pkt)
return true
}
@@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// destination address, then do nothing further and instruct the caller to do
// the same. The network layer handles address validation for specified source
// addresses.
- if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) {
+ if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
// TCP can only be used to communicate between a single source and a
- // single destination; the addresses must be unicast.
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ // single destination; the addresses must be unicast.e
+ d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
return true
}
@@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RUnlock()
if ep == nil {
if protocol == header.UDPProtocolNumber {
- r.Stats().UDP.UnknownPortErrors.Increment()
+ d.stack.stats.UDP.UnknownPortErrors.Increment()
}
return false
}
- ep.handlePacket(r, id, pkt)
+ ep.handlePacket(id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
- eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
+ eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
if !ok {
return false
}
@@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, pkt)
+ rawEP.HandlePacket(pkt.Clone())
foundRaw = true
}
eps.mu.RUnlock()
@@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
}
// findTransportEndpoint find a single endpoint that most closely matches the provided id.
-func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
eps, ok := d.protocol[protocolIDs{netProto, transProto}]
if !ok {
return nil
@@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
epsByNIC.mu.RLock()
eps.mu.RUnlock()
- mpep, ok := epsByNIC.endpoints[r.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[nicID]
if !ok {
if mpep, ok = epsByNIC.endpoints[0]; !ok {
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
@@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isInboundMulticastOrBroadcast(r *Route) bool {
- return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
+func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
+ return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}
func isSpecified(addr tcpip.Address) bool {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 62ab6d92f..c457b67a2 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -28,7 +28,7 @@ import (
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
+ fakeTransHeaderLen int = 3
)
// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
@@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
- if f.acceptQueue != nil {
- f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- TransportEndpointInfo: stack.TransportEndpointInfo{
- ID: f.ID,
- NetProto: f.NetProto,
- },
- proto: f.proto,
- peerAddr: r.RemoteAddress,
- route: r.Clone(),
- })
+ if f.acceptQueue == nil {
+ return
}
+
+ netHdr := pkt.NetworkHeader().View()
+ route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return
+ }
+ route.ResolveWith(pkt.SourceLinkAddress())
+
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ TransportEndpointInfo: stack.TransportEndpointInfo{
+ ID: f.ID,
+ NetProto: f.NetProto,
+ },
+ proto: f.proto,
+ peerAddr: route.RemoteAddress,
+ route: route,
+ })
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
@@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index d77848d61..3ab2b7654 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool {
return s.Prefix() <= 30 && s.Broadcast() == address
}
-// Equal returns true if s equals o.
-//
-// Needed to use cmp.Equal on Subnet as its fields are unexported.
+// Equal returns true if this Subnet is equal to the given Subnet.
func (s Subnet) Equal(o Subnet) bool {
+ // If this changes, update Route.Equal accordingly.
return s == o
}
@@ -763,6 +762,10 @@ const (
// endpoint that all packets being written have an IP header and the
// endpoint should not attach an IP header.
IPHdrIncludedOption
+
+ // AcceptConnOption is used by GetSockOptBool to indicate if the
+ // socket is a listening socket.
+ AcceptConnOption
)
// SockOptInt represents socket options which values have the int type.
@@ -1256,6 +1259,12 @@ func (r Route) String() string {
return out.String()
}
+// Equal returns true if the given Route is equal to this Route.
+func (r Route) Equal(to Route) bool {
+ // NOTE: This relies on the fact that r.Destination == to.Destination
+ return r == to
+}
+
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -1496,6 +1505,15 @@ type IPStats struct {
// IPTablesOutputDropped is the total number of IP packets dropped in
// the Output chain.
IPTablesOutputDropped *StatCounter
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived *StatCounter
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived *StatCounter
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived *StatCounter
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 34aab32d0..9b0f3b675 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -10,6 +10,7 @@ go_test(
"link_resolution_test.go",
"loopback_test.go",
"multicast_broadcast_test.go",
+ "route_test.go",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 0dcef7b04..bf7594268 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -33,11 +33,6 @@ import (
func TestForwarding(t *testing.T) {
const (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
- routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
-
host1NICID = 1
routerNICID1 = 2
routerNICID2 = 3
@@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) {
}
},
},
+ {
+ name: "IPv4 host2 server with routerNIC1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
+ {
+ name: "IPv6 routerNIC2 server with host1 client",
+ epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses {
+ ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber)
+ return endpointAndAddresses{
+ serverEP: ep1,
+ serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address,
+ serverReadableCH: ep1WECH,
+
+ clientEP: ep2,
+ clientAddr: host1IPv6Addr.AddressWithPrefix.Address,
+ clientReadableCH: ep2WECH,
+ }
+ },
+ },
}
for _, test := range tests {
@@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) {
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr)
- routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr)
+ host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2)
+ routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) {
if err == tcpip.ErrNoLinkAddress {
// Wait for link resolution to complete.
<-ch
-
n, _, err = ep.Write(dataPayload, wOpts)
- } else if err != nil {
- t.Fatalf("ep.Write(_, _): %s", err)
}
-
if err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
}
@@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) {
// Wait for the endpoint to be readable.
<-ch
-
var addr tcpip.FullAddress
v, _, err := ep.Read(&addr)
if err != nil {
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 6ddcda70c..fe7c1bb3d 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -32,32 +32,36 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-var (
- host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
- host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+const (
+ linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
+)
- host1IPv4Addr = tcpip.ProtocolAddress{
+var (
+ ipv4Addr1 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()),
PrefixLen: 24,
},
}
- host2IPv4Addr = tcpip.ProtocolAddress{
+ ipv4Addr2 = tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()),
PrefixLen: 8,
},
}
- host1IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr1 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::1").To16()),
PrefixLen: 64,
},
}
- host2IPv6Addr = tcpip.ProtocolAddress{
+ ipv6Addr2 = tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.Address(net.ParseIP("a::2").To16()),
@@ -89,7 +93,7 @@ func TestPing(t *testing.T) {
name: "IPv4 Ping",
transProto: icmp.ProtocolNumber4,
netProto: ipv4.ProtocolNumber,
- remoteAddr: host2IPv4Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
@@ -104,7 +108,7 @@ func TestPing(t *testing.T) {
name: "IPv6 Ping",
transProto: icmp.ProtocolNumber6,
netProto: ipv6.ProtocolNumber,
- remoteAddr: host2IPv6Addr.AddressWithPrefix.Address,
+ remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
icmpBuf: func(t *testing.T) buffer.View {
data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
@@ -127,7 +131,7 @@ func TestPing(t *testing.T) {
host1Stack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr)
+ host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2)
if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
@@ -143,36 +147,36 @@ func TestPing(t *testing.T) {
t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err)
}
- if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil {
- t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err)
+ if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil {
+ t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err)
}
- if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil {
- t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err)
+ if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil {
+ t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err)
}
host1Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host1IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
tcpip.Route{
- Destination: host1IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr1.AddressWithPrefix.Subnet(),
NIC: host1NICID,
},
})
host2Stack.SetRouteTable([]tcpip.Route{
tcpip.Route{
- Destination: host2IPv4Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv4Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
tcpip.Route{
- Destination: host2IPv6Addr.AddressWithPrefix.Subnet(),
+ Destination: ipv6Addr2.AddressWithPrefix.Subnet(),
NIC: host2NICID,
},
})
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index e8caf09ba..421da1add 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) {
},
})
- wq := waiter.Queue{}
+ var wq waiter.Queue
rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index f1028823b..cdf0459e3 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
t.Fatalf("got unexpected address length = %d bytes", l)
}
- wq := waiter.Queue{}
+ var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
if err != nil {
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
@@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff")
)
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
-
tests := []struct {
name string
broadcastAddr tcpip.Address
@@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
},
})
+ type endpointAndWaiter struct {
+ ep tcpip.Endpoint
+ ch chan struct{}
+ }
+ var eps []endpointAndWaiter
// We create endpoints that bind to both the wildcard address and the
// broadcast address to make sure both of these types of "broadcast
// interested" endpoints receive broadcast packets.
- wq := waiter.Queue{}
- var eps []tcpip.Endpoint
for _, bindWildcard := range []bool{false, true} {
// Create multiple endpoints for each type of "broadcast interested"
// endpoint so we can test that all endpoints receive the broadcast
// packet.
for i := 0; i < 2; i++ {
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err)
@@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
}
}
- eps = append(eps, ep)
+ eps = append(eps, endpointAndWaiter{ep: ep, ch: ch})
}
}
@@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- if n, _, err := wep.Write(data, writeOpts); err != nil {
+ data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
+ if n, _, err := wep.ep.Write(data, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want)
}
for j, rep := range eps {
- if gotPayload, _, err := rep.Read(nil); err != nil {
+ // Wait for the endpoint to become readable.
+ <-rep.ch
+
+ if gotPayload, _, err := rep.ep.Read(nil); err != nil {
t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err)
} else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff)
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
new file mode 100644
index 000000000..02fc47015
--- /dev/null
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -0,0 +1,388 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TestLocalPing tests pinging a remote that is local the stack.
+//
+// This tests that a local route is created and packets do not leave the stack.
+func TestLocalPing(t *testing.T) {
+ const (
+ nicID = 1
+ ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
+
+ // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo
+ // request/reply packets.
+ icmpDataOffset = 8
+ )
+
+ channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") }
+ channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) {
+ channelEP := e.(*channel.Endpoint)
+ if n := channelEP.Drain(); n != 0 {
+ t.Fatalf("got channelEP.Drain() = %d, want = 0", n)
+ }
+ }
+
+ ipv4ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
+ hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv4Echo)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ ipv6ICMPBuf := func(t *testing.T) buffer.View {
+ data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9}
+ hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data)))
+ hdr.SetType(header.ICMPv6EchoRequest)
+ if n := copy(hdr.Payload(), data[:]); n != len(data) {
+ t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data))
+ }
+ return buffer.View(hdr)
+ }
+
+ tests := []struct {
+ name string
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+ linkEndpoint func() stack.LinkEndpoint
+ localAddr tcpip.Address
+ icmpBuf func(*testing.T) buffer.View
+ expectedConnectErr *tcpip.Error
+ checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint)
+ }{
+ {
+ name: "IPv4 loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: ipv4Loopback,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ localAddr: header.IPv6Loopback,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv4Addr.Address,
+ icmpBuf: ipv4ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ localAddr: ipv6Addr.Address,
+ icmpBuf: ipv6ICMPBuf,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv4 loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv6 loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: loopback.New,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {},
+ },
+ {
+ name: "IPv4 non-loopback without local address",
+ transProto: icmp.ProtocolNumber4,
+ netProto: ipv4.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv4ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ {
+ name: "IPv6 non-loopback without local address",
+ transProto: icmp.ProtocolNumber6,
+ netProto: ipv6.ProtocolNumber,
+ linkEndpoint: channelEP,
+ icmpBuf: ipv6ICMPBuf,
+ expectedConnectErr: tcpip.ErrNoRoute,
+ checkLinkEndpoint: channelEPCheck,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ })
+ e := test.linkEndpoint()
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if len(test.localAddr) != 0 {
+ if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err)
+ }
+ }
+
+ var wq waiter.Queue
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err)
+ }
+ defer ep.Close()
+
+ connAddr := tcpip.FullAddress{Addr: test.localAddr}
+ if err := ep.Connect(connAddr); err != test.expectedConnectErr {
+ t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr)
+ }
+
+ if test.expectedConnectErr != nil {
+ return
+ }
+
+ payload := tcpip.SlicePayload(test.icmpBuf(t))
+ var wOpts tcpip.WriteOptions
+ if n, _, err := ep.Write(payload, wOpts); err != nil {
+ t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
+ } else if n != int64(len(payload)) {
+ t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload))
+ }
+
+ // Wait for the endpoint to become readable.
+ <-ch
+
+ var addr tcpip.FullAddress
+ v, _, err := ep.Read(&addr)
+ if err != nil {
+ t.Fatalf("ep.Read(_): %s", err)
+ }
+ if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" {
+ t.Errorf("received data mismatch (-want +got):\n%s", diff)
+ }
+ if addr.Addr != test.localAddr {
+ t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr)
+ }
+
+ test.checkLinkEndpoint(t, e)
+ })
+ }
+}
+
+// TestLocalUDP tests sending UDP packets between two endpoints that are local
+// to the stack.
+//
+// This tests that that packets never leave the stack and the addresses
+// used when sending a packet.
+func TestLocalUDP(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ tests := []struct {
+ name string
+ canBePrimaryAddr tcpip.ProtocolAddress
+ firstPrimaryAddr tcpip.ProtocolAddress
+ }{
+ {
+ name: "IPv4",
+ canBePrimaryAddr: ipv4Addr1,
+ firstPrimaryAddr: ipv4Addr2,
+ },
+ {
+ name: "IPv6",
+ canBePrimaryAddr: ipv6Addr1,
+ firstPrimaryAddr: ipv6Addr2,
+ },
+ }
+
+ subTests := []struct {
+ name string
+ addAddress bool
+ expectedWriteErr *tcpip.Error
+ }{
+ {
+ name: "Unassigned local address",
+ addAddress: false,
+ expectedWriteErr: tcpip.ErrNoRoute,
+ },
+ {
+ name: "Assigned local address",
+ addAddress: true,
+ expectedWriteErr: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ stackOpts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ HandleLocal: true,
+ }
+
+ s := stack.New(stackOpts)
+ ep := channel.New(1, header.IPv6MinimumMTU, "")
+
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if subTest.addAddress {
+ if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err)
+ }
+ }
+
+ var serverWQ waiter.Queue
+ serverWE, serverCH := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverWE, waiter.EventIn)
+ server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer server.Close()
+
+ bindAddr := tcpip.FullAddress{Port: 80}
+ if err := server.Bind(bindAddr); err != nil {
+ t.Fatalf("server.Bind(%#v): %s", bindAddr, err)
+ }
+
+ var clientWQ waiter.Queue
+ clientWE, clientCH := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientWE, waiter.EventIn)
+ client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err)
+ }
+ defer client.Close()
+
+ serverAddr := tcpip.FullAddress{
+ Addr: test.canBePrimaryAddr.AddressWithPrefix.Address,
+ Port: 80,
+ }
+
+ clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &serverAddr,
+ }
+ if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
+ } else if subTest.expectedWriteErr != nil {
+ // Nothing else to test if we expected not to be able to send the
+ // UDP packet.
+ return
+ } else if n != int64(len(clientPayload)) {
+ t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload))
+ }
+ }
+
+ // Wait for the server endpoint to become readable.
+ <-serverCH
+
+ var clientAddr tcpip.FullAddress
+ if v, _, err := server.Read(&clientAddr); err != nil {
+ t.Fatalf("server.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" {
+ t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff)
+ }
+ if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address {
+ t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ {
+ wOpts := tcpip.WriteOptions{
+ To: &clientAddr,
+ }
+ if n, _, err := server.Write(serverPayload, wOpts); err != nil {
+ t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
+ } else if n != int64(len(serverPayload)) {
+ t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload))
+ }
+ }
+
+ // Wait for the client endpoint to become readable.
+ <-clientCH
+
+ var gotServerAddr tcpip.FullAddress
+ if v, _, err := client.Read(&gotServerAddr); err != nil {
+ t.Fatalf("client.Read(_): %s", err)
+ } else {
+ if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" {
+ t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff)
+ }
+ if gotServerAddr.Addr != serverAddr.Addr {
+ t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 41eb0ca44..763cd8f84 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
default:
@@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
@@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
},
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 87d510f96..3820e5dc7 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
return stack.UnknownDestinationPacketHandled
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 072601d2d..31831a6d8 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.AcceptConnOption:
+ return false, nil
+ default:
+ return false, tcpip.ErrNotSupported
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e37c00523..7b6a87ba9 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.IPHdrIncludedOption:
@@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full or if this is an unassociated
@@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ remoteAddr := pkt.Network().SourceAddress()
+
if e.bound {
// If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+ if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
e.rcvMu.Unlock()
return
}
// If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+ if e.BindAddr != "" && e.BindAddr != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// If connected, only accept packets from the remote address we
// connected to.
- if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+ if e.connected && e.route.RemoteAddress != remoteAddr {
e.rcvMu.Unlock()
return
}
@@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
// Push new packet into receive list and increment the buffer size.
packet := &rawPacket{
senderAddr: tcpip.FullAddress{
- NIC: route.NICID(),
- Addr: route.RemoteAddress,
+ NIC: pkt.NICID,
+ Addr: remoteAddr,
},
}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 33bfb56cd..7d97cbdc7 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
}
// beforeSave is invoked by stateify.
-func (ep *endpoint) beforeSave() {
+func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
// The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
// packets.
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
}
// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
// Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
+ e.rcvBufSizeMax = 0
// Release the lock acquired in beforeSave() so regular endpoint closing
// logic can proceed after save.
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return max
}
// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
}
// afterLoad is invoked by stateify.
-func (ep *endpoint) afterLoad() {
- stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
-func (ep *endpoint) Resume(s *stack.Stack) {
- ep.stack = s
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
// If the endpoint is connected, re-connect.
- if ep.connected {
+ if e.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
if err != nil {
panic(err)
}
}
// If the endpoint is bound, re-bind.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
+ if e.bound {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if ep.associated {
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ if e.associated {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b706438bd..47982ca41 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
- netProto = s.route.NetProto
+ netProto = s.netProto
}
+
+ route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(s.remoteLinkAddr)
+
n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6Only
n.ID = s.id
- n.boundNICID = s.route.NICID()
- n.route = s.route.Clone()
- n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.boundNICID = s.nicID
+ n.route = route
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto}
n.rcvBufSize = int(l.rcvWnd)
n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
n.setEndpointState(StateConnecting)
@@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- return n
+ return n, nil
}
// createEndpointAndPerformHandshake creates a new endpoint in connected state
@@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+ if err != nil {
+ return nil, err
+ }
// Lock the endpoint before registering to ensure that no out of
// band changes are possible due to incoming packets etc till
@@ -425,20 +435,17 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer ctx.synRcvdCount.dec()
- defer func() {
- e.mu.Lock()
- e.decSynRcvdCount()
- e.mu.Unlock()
- }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
+ e.decSynRcvdCount()
return
}
ctx.removePendingEndpoint(n)
+ e.decSynRcvdCount()
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
@@ -456,7 +463,9 @@ func (e *endpoint) incSynRcvdCount() bool {
}
func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
e.synRcvdCount--
+ e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
@@ -468,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
-func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
e.rcvListMu.Unlock()
@@ -478,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
// must be sent in response to a SYN-ACK while in the listen
// state to prevent completing a handshake from an old SYN.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
switch {
@@ -493,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
- return
+ return nil
}
ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
} else {
// If cookies are in use but the endpoint accept queue
// is full then drop the syn.
@@ -507,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Send SYN without window scaling because we currently
// don't encode this information in the cookie.
//
@@ -524,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
TS: opts.TS,
TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, s.route),
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
}
- e.sendSynTCP(&s.route, tcpFields{
+ fields := tcpFields{
id: s.id,
ttl: e.ttl,
tos: e.sendTOS,
@@ -534,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
seq: cookie,
ack: s.sequenceNumber + 1,
rcvWnd: ctx.rcvWnd,
- }, synOpts)
+ }
+ if err := e.sendSynTCP(&route, fields, synOpts); err != nil {
+ return err
+ }
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
}
case (s.flags & header.TCPFlagAck) != 0:
@@ -548,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
if !ctx.synRcvdCount.synCookiesInUse() {
@@ -567,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s, e.sendTOS, e.ttl)
- return
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
iss := s.ackNumber - 1
@@ -588,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return
+ return nil
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
@@ -609,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+ if err != nil {
+ return err
+ }
n.mu.Lock()
@@ -623,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return nil
}
// Register new endpoint so that packets are routed to it.
@@ -633,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- return
+ return err
}
n.isRegistered = true
@@ -671,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
+ return nil
+
+ default:
+ return nil
}
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
-func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.mu.Lock()
v6Only := e.v6only
ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
@@ -715,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
- return nil
+ return
}
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
s := e.segmentQueue.dequeue()
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
close(e.drainDone)
@@ -739,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
break
}
- e.handleListenSegment(ctx, s)
+ // TODO(gvisor.dev/issue/4690): Better handle errors instead of
+ // silently dropping.
+ _ = e.handleListenSegment(ctx, s)
s.decRef()
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 0aaef495d..2facbebec 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
MSS: amss,
}
if ttl == 0 {
- ttl = s.route.DefaultTTL()
+ ttl = h.ep.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: ttl,
tos: h.ep.sendTOS,
@@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, tcpFields{
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
tos: h.ep.sendTOS,
@@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
// Wait for notification.
- index, _ = s.Fetch(true)
+ h.ep.mu.Unlock()
+ index, _ = s.Fetch(true /* block */)
+ h.ep.mu.Lock()
}
}
@@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error {
}, synOpts)
for h.state != handshakeCompleted {
+ // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
+ // throughout handshake processing).
h.ep.mu.Unlock()
- index, _ := s.Fetch(true)
+ index, _ := s.Fetch(true /* block */)
h.ep.mu.Lock()
switch index {
@@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta
// TCP header, then the kernel calculate a checksum of the
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
- } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ } else if r.RequiresTXTransportChecksum() {
xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
@@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() {
// only when the endpoint is in StateClose and we want to deliver the segment
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
- ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID)
if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
// Dual-stack socket, try IPv4.
- ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID)
}
if ep == nil {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)
s.decRef()
return
}
@@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
drained := e.drainDone != nil
if drained {
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
// Set up the functions that will be called when the main protocol loop
@@ -1535,7 +1541,7 @@ loop:
}
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
// We need to double check here because the notification may be
@@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
}
for _, netProto := range netProtos {
- if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil {
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
@@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
for {
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
switch v {
case newSegment:
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index 98aecab9e..21162f01a 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -172,10 +172,11 @@ func (d *dispatcher) wait() {
d.wg.Wait()
}
-func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
ep := stackEP.(*endpoint)
- s := newSegment(r, id, pkt)
- if !s.parse() {
+
+ s := newIncomingSegment(id, pkt)
+ if !s.parse(pkt.RXTransportChecksumValidated) {
ep.stack.Stats().MalformedRcvdPackets.Increment()
ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 560b4904c..a6f25896b 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
testV6Connect(t, c)
}
+func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) {
+ c := context.NewWithOpts(t, context.Options{
+ EnableV6: true,
+ MTU: defaultMTU,
+ })
+ defer c.Cleanup()
+
+ // Create a v6 endpoint but don't set the v6-only TCP option.
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c)
+}
+
func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3bcd3923a..258f9f1bb 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -721,9 +721,9 @@ func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
// by another user goroutine if not then we spin, otherwise
- // we just goto sleep on the Lock() and wait.
+ // we just go to sleep on the Lock() and wait.
if !e.mu.TryLock() {
- // If socket is owned by the user then just goto sleep
+ // If socket is owned by the user then just go to sleep
// as the lock could be held for a reasonably long time.
if atomic.LoadUint32(&e.ownedByUser) == 1 {
e.mu.Lock()
@@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
// Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
+ s := newOutgoingSegment(e.ID, v)
e.sndBufUsed += len(v)
e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
@@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.MulticastLoopOption:
return true, nil
+ case tcpip.AcceptConnOption:
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.EndpointState() == StateListen, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -2310,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// done yet) or the reservation was freed between the check above and
// the FindTransportEndpoint below. But rather than retry the same port
// we just skip it and move on.
- transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID())
if transEP == nil {
// ReservePort failed but there is no registered endpoint with
// demuxer. Which indicates there is at least some endpoint that has
@@ -2379,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
for s := l.Front(); s != nil; s = s.Next() {
s.id = e.ID
- s.route = r.Clone()
e.sndWaker.Assert()
}
}
@@ -2445,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
}
// Queue fin segment.
- s := newSegmentFromView(&e.route, e.ID, nil)
+ s := newOutgoingSegment(e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
// Mark endpoint as closed.
@@ -2627,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
return err
}
- // Expand netProtos to include v4 and v6 if the caller is binding to a
- // wildcard (empty) address, and this is an IPv6 endpoint with v6only
- // set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
- netProtos = []tcpip.NetworkProtocolNumber{
- header.IPv6ProtocolNumber,
- header.IPv4ProtocolNumber,
+
+ // Expand netProtos to include v4 and v6 under dual-stack if the caller is
+ // binding to a wildcard (empty) address, and this is an IPv6 endpoint with
+ // v6only set to false.
+ if netProto == header.IPv6ProtocolNumber {
+ stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber)
+ alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4
+ if alsoBindToV4 {
+ netProtos = append(netProtos, header.IPv4ProtocolNumber)
}
}
@@ -2715,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress {
}
}
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) {
// TCP HandlePacket is not required anymore as inbound packets first
// land at the Dispatcher which then can either delivery using the
// worker go routine or directly do the invoke the tcp processing inline
@@ -3074,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() {
}
func (e *endpoint) initGSO() {
- if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ if e.route.HasHardwareGSOCapability() {
e.initHardwareGSO()
- } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 {
+ } else if e.route.HasSoftwareGSOCapability() {
e.gso = &stack.GSO{
MaxSize: e.route.GSOMaxSize(),
Type: stack.GSOSW,
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b25431467..2bcc5e1c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() {
switch {
case epState == StateInitial || epState == StateBound:
case epState.connected() || epState.handshake():
- if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
- if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ if !e.route.HasSaveRestoreCapability() {
+ if !e.route.HasDisconncetOkCapability() {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
}
e.resetConnectionLocked(tcpip.ErrConnectionAborted)
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 070b634b4..0664789da 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -30,6 +30,8 @@ import (
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
type Forwarder struct {
+ stack *stack.Stack
+
maxInFlight int
handler func(*ForwarderRequest)
@@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
rcvWnd = DefaultReceiveBufferSize
}
return &Forwarder{
+ stack: s,
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
@@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
- s := newSegment(r, id, pkt)
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
- if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn {
return false
}
@@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
delete(r.forwarder.inFlight, r.segment.id)
r.forwarder.mu.Unlock()
- // If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
+ replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */)
}
// Release all resources.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 5bce73605..2329aca4b 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// to a specific processing queue. Each queue is serviced by its own processor
// goroutine which is responsible for dequeuing and doing full TCP dispatch of
// the packet.
-func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
- p.dispatcher.queuePacket(r, ep, id, pkt)
+func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(ep, id, pkt)
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
@@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
- s := newSegment(r, id, pkt)
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ s := newIncomingSegment(id, pkt)
defer s.decRef()
- if !s.parse() || !s.csumValid {
+ if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid {
return stack.UnknownDestinationPacketMalformed
}
if !s.flagIsSet(header.TCPFlagRst) {
- replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ replyWithReset(p.stack, s, stack.DefaultTOS, 0)
}
return stack.UnknownDestinationPacketHandled
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment, tos, ttl uint8) {
+//
+// If the passed TTL is 0, then the route's default TTL will be used.
+func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error {
+ route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+ route.ResolveWith(s.remoteLinkAddr)
+
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
@@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) {
flags |= header.TCPFlagAck
ack = s.sequenceNumber.Add(s.logicalLen())
}
- sendTCP(&s.route, tcpFields{
+
+ if ttl == 0 {
+ ttl = route.DefaultTTL()
+ }
+
+ return sendTCP(&route, tcpFields{
id: s.id,
ttl: ttl,
tos: tos,
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 7ef2df377..833a7b470 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
return found
}
-// Dump prints the state of the scoreboard structure.
+// String returns human-readable state of the scoreboard structure.
func (s *SACKScoreboard) String() string {
var str strings.Builder
str.WriteString("SACKScoreboard: {")
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 1f9c5cf50..2091989cc 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -19,6 +19,7 @@ import (
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -45,9 +46,18 @@ type segment struct {
ep *endpoint
qFlags queueFlags
id stack.TransportEndpointID `state:"manual"`
- route stack.Route `state:"manual"`
- data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- hdr header.TCP
+
+ // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of
+ // individual members for link/network packet info.
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ netProto tcpip.NetworkProtocolNumber
+ nicID tcpip.NICID
+ remoteLinkAddr tcpip.LinkAddress
+
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -76,11 +86,16 @@ type segment struct {
acked bool
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
+ netHdr := pkt.Network()
s := &segment{
- refCnt: 1,
- id: id,
- route: r.Clone(),
+ refCnt: 1,
+ id: id,
+ srcAddr: netHdr.SourceAddress(),
+ dstAddr: netHdr.DestinationAddress(),
+ netProto: pkt.NetworkProtocolNumber,
+ nicID: pkt.NICID,
+ remoteLinkAddr: pkt.SourceLinkAddress(),
}
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
@@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB
return s
}
-func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
- route: r.Clone(),
}
s.rcvdTime = time.Now()
if len(v) != 0 {
@@ -110,7 +124,9 @@ func (s *segment) clone() *segment {
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
- route: s.route.Clone(),
+ netProto: s.netProto,
+ nicID: s.nicID,
+ remoteLinkAddr: s.remoteLinkAddr,
viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
@@ -160,7 +176,6 @@ func (s *segment) decRef() {
panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
}
}
- s.route.Release()
}
}
@@ -198,10 +213,10 @@ func (s *segment) segMemSize() int {
//
// Returns boolean indicating if the parsing was successful.
//
-// If checksum verification is not offloaded then parse also verifies the
+// If checksum verification may not be skipped, parse also verifies the
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
-func (s *segment) parse() bool {
+func (s *segment) parse(skipChecksumValidation bool) bool {
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -220,16 +235,14 @@ func (s *segment) parse() bool {
s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
- // Query the link capabilities to decide if checksum validation is
- // required.
verifyChecksum := true
- if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ if skipChecksumValidation {
s.csumValid = true
verifyChecksum = false
}
if verifyChecksum {
s.csum = s.hdr.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr)))
xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 6fa8d63cd..ab5fa4fb7 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
// steps 2 and 3.
func (s *sender) walkSACK(rcvdSeg *segment) {
+ if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 {
+ return
+ }
+
// Sort the SACK blocks. The first block is the most recent unacked
// block. The following blocks can be in arbitrary order.
sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index a7149efd0..5f05608e2 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) {
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
@@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
}
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
@@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) {
// Test acceptance.
// Start listening.
- listenBacklog := 2
+ listenBacklog := 10
if err := c.EP.Listen(listenBacklog); err != nil {
t.Fatalf("Listen failed: %s", err)
}
- for i := 0; i < listenBacklog; i++ {
- executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */)
+ lastPortOffset := uint16(0)
+ for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
time.Sleep(50 * time.Millisecond)
@@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now execute send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 2,
+ SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(789),
@@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
@@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+ // drain any older notifications from the notification channel before attempting
+ // 2nd connection.
+ select {
+ case <-ch:
+ default:
+ }
+
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
iss = seqnum.Value(792)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 4d7847142..f791f8f13 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -112,6 +112,18 @@ type Headers struct {
TCPOpts []byte
}
+// Options contains options for creating a new test context.
+type Options struct {
+ // EnableV4 indicates whether IPv4 should be enabled.
+ EnableV4 bool
+
+ // EnableV6 indicates whether IPv4 should be enabled.
+ EnableV6 bool
+
+ // MTU indicates the maximum transmission unit on the link layer.
+ MTU uint32
+}
+
// Context provides an initialized Network stack and a link layer endpoint
// for use in TCP tests.
type Context struct {
@@ -154,10 +166,30 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ return NewWithOpts(t, Options{
+ EnableV4: true,
+ EnableV6: true,
+ MTU: mtu,
})
+}
+
+// NewWithOpts allocates and initializes a test context containing a new
+// stack and a link-layer endpoint with specific options.
+func NewWithOpts(t *testing.T, opts Options) *Context {
+ if opts.MTU == 0 {
+ panic("MTU must be greater than 0")
+ }
+
+ stackOpts := stack.Options{
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ }
+ if opts.EnableV4 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
+ }
+ if opts.EnableV6 {
+ stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
+ }
+ s := stack.New(stackOpts)
const sendBufferSize = 1 << 20 // 1 MiB
const recvBufferSize = 1 << 20 // 1 MiB
@@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- ep := channel.New(1000, mtu, "")
+ ep := channel.New(1000, opts.MTU, "")
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}
- opts := stack.NICOptions{Name: "nic1"}
- if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
if testing.Verbose() {
- wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
}
opts2 := stack.NICOptions{Name: "nic2"}
if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
- v4ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv4.ProtocolNumber,
- AddressWithPrefix: StackAddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
- }
-
- v6ProtocolAddr := tcpip.ProtocolAddress{
- Protocol: ipv6.ProtocolNumber,
- AddressWithPrefix: StackV6AddrWithPrefix,
- }
- if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
- }
+ var routeTable []tcpip.Route
- s.SetRouteTable([]tcpip.Route{
- {
+ if opts.EnableV4 {
+ v4ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv4.ProtocolNumber,
+ AddressWithPrefix: StackAddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
- },
- {
+ })
+ }
+
+ if opts.EnableV6 {
+ v6ProtocolAddr := tcpip.ProtocolAddress{
+ Protocol: ipv6.ProtocolNumber,
+ AddressWithPrefix: StackV6AddrWithPrefix,
+ }
+ if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
+ }
+ routeTable = append(routeTable, tcpip.Route{
Destination: header.IPv6EmptySubnet,
NIC: 1,
- },
- })
+ })
+ }
+
+ s.SetRouteTable(routeTable)
return &Context{
t: t,
@@ -373,6 +410,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code,
const icmpv4VariableHeaderOffset = 4
copy(icmp[icmpv4VariableHeaderOffset:], p1)
copy(icmp[header.ICMPv4PayloadOffset:], p2)
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
// Inject packet.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d57ed5d79..9bcb918bb 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
nicID = e.BindNICID
}
+ if to.Port == 0 {
+ // Port 0 is an invalid port to send to.
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
@@ -895,6 +900,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.AcceptConnOption:
+ return false, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1009,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// On IPv4, UDP checksum is optional, and a zero value indicates the
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ if r.RequiresTXTransportChecksum() &&
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
@@ -1366,6 +1374,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.rcvMu.Unlock()
}
+ e.lastErrorMu.Lock()
+ hasError := e.lastError != nil
+ e.lastErrorMu.Unlock()
+ if hasError {
+ result |= waiter.EventErr
+ }
return result
}
@@ -1373,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
// omitted the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
-func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool {
- if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
- (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
+ if !pkt.RXTransportChecksumValidated &&
+ (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) {
+ netHdr := pkt.Network()
+ xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length())
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
}
@@ -1387,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
@@ -1397,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
return
}
- if !verifyChecksum(r, hdr, pkt) {
+ if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
e.stats.ReceiveErrors.ChecksumErrors.Increment()
@@ -1428,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
senderAddress: tcpip.FullAddress{
- NIC: r.NICID(),
+ NIC: pkt.NICID,
Addr: id.RemoteAddress,
Port: header.UDP(hdr).SourcePort(),
},
@@ -1438,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvBufSize += pkt.Data.Size()
// Save any useful information from the network header to the packet.
- switch r.NetProto {
+ switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
case header.IPv6ProtocolNumber:
@@ -1448,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
// address. packetInfo.LocalAddr should hold a unicast address that can be
// used to respond to the incoming packet.
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.LocalAddress
- packet.packetInfo.NIC = r.NICID()
+ localAddr := pkt.Network().DestinationAddress()
+ packet.packetInfo.LocalAddr = localAddr
+ packet.packetInfo.DestinationAddr = localAddr
+ packet.packetInfo.NIC = pkt.NICID
packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1465,14 +1481,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
- defer e.mu.RUnlock()
-
if e.state == StateConnected {
e.lastErrorMu.Lock()
- defer e.lastErrorMu.Unlock()
-
e.lastError = tcpip.ErrConnectionRefused
+ e.lastErrorMu.Unlock()
+ e.mu.RUnlock()
+
+ e.waiterQueue.Notify(waiter.EventErr)
+ return
}
+ e.mu.RUnlock()
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 3ae6cc221..14e4648cd 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
- route: r,
id: id,
pkt: pkt,
})
@@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p
// it via CreateEndpoint.
type ForwarderRequest struct {
stack *stack.Stack
- route *stack.Route
id stack.TransportEndpointID
pkt *stack.PacketBuffer
}
@@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ netHdr := r.pkt.Network()
+ route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return nil, err
+ }
+ route.ResolveWith(r.pkt.SourceLinkAddress())
+
+ ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
+ route.Release()
return nil, err
}
ep.ID = r.id
- ep.route = r.route.Clone()
+ ep.route = route
ep.dstPort = r.id.RemotePort
- ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto}
- ep.RegisterNICID = r.route.NICID()
+ ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}
+ ep.RegisterNICID = r.pkt.NICID
ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
@@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.pkt)
+ ep.HandlePacket(r.id, r.pkt)
return ep, nil
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index da5b1deb2..91420edd3 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets that are targeted at this
// protocol but don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
- r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ p.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return stack.UnknownDestinationPacketMalformed
}
- if !verifyChecksum(r, hdr, pkt) {
- r.Stack().Stats().UDP.ChecksumErrors.Increment()
+ if !verifyChecksum(hdr, pkt) {
+ p.stack.Stats().UDP.ChecksumErrors.Increment()
return stack.UnknownDestinationPacketMalformed
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b4604ba35..fb7738dda 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) {
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
-
}
// In the case of large payloads the IP packet may be truncated. Update
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index 5c4b9e8e9..a38ffc19d 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -53,40 +53,40 @@ func randomFilename() (string, error) {
func TestConnectFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
if _, err := Connect(name, false); err == nil {
- t.Fatalf("connect was successful, expected err")
+ t.Fatalf("Connect was successful, expected err")
}
}
func TestBindFailure(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
if _, err = BindAndListen(name, false); err == nil {
- t.Fatalf("second bind succeeded, expected non-nil err")
+ t.Fatalf("Second bind succeeded, expected non-nil err")
}
}
func TestMultipleAccept(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
defer ss.Close()
@@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) {
defer wg.Done()
s, err := Connect(name, false)
if err != nil {
- t.Fatalf("connect failed, got err %v expected nil", err)
+ t.Errorf("Connect failed, got err %v expected nil", err)
+ return
}
s.Close()
}()
@@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) {
for i := 0; i < backlog; i++ {
s, err := ss.Accept()
if err != nil {
- t.Errorf("accept failed, got err %v expected nil", err)
+ t.Errorf("Accept failed, got err %v expected nil", err)
continue
}
s.Close()
@@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) {
func TestServerClose(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("first bind failed, got err %v expected nil", err)
+ t.Fatalf("First bind failed, got err %v expected nil", err)
}
// Make sure the first close succeeds.
if err := ss.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := ss.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
// Bind a server.
ss, err := BindAndListen(name, packet)
if err != nil {
- t.Fatalf("error binding, got %v expected nil", err)
+ t.Fatalf("Error binding, got %v expected nil", err)
}
defer ss.Close()
@@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
// Connect the client.
client, err := Connect(name, packet)
if err != nil {
- t.Fatalf("error connecting, got %v expected nil", err)
+ t.Fatalf("Error connecting, got %v expected nil", err)
}
// Grab the server handle.
@@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) {
case server := <-acceptSocket:
return server, client
case err := <-acceptErr:
- t.Fatalf("accept error: %v", err)
+ t.Fatalf("Accept error: %v", err)
}
panic("unreachable")
}
@@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) {
// Write on the server.
w := server.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -233,13 +234,13 @@ func TestPacket(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Write on the client again.
w = client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the server.
@@ -249,19 +250,19 @@ func TestPacket(t *testing.T) {
b := [][]byte{{'b', 'b'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Do it again.
r = server.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
}
@@ -271,12 +272,12 @@ func TestClose(t *testing.T) {
// Make sure the first close succeeds.
if err := client.Close(); err != nil {
- t.Fatalf("first close failed, got err %v expected nil", err)
+ t.Fatalf("First close failed, got err %v expected nil", err)
}
// The second one should fail.
if err := client.Close(); err == nil {
- t.Fatalf("second close succeeded, expected non-nil err")
+ t.Fatalf("Second close succeeded, expected non-nil err")
}
}
@@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) {
// We're good. That's what we wanted.
blockCount++
} else {
- t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err)
}
}
}
if blockCount == 1000 {
// Shouldn't have _always_ blocked.
- t.Fatalf("socket always blocked!")
+ t.Fatalf("Socket always blocked!")
} else if blockCount == 0 {
// Should have started blocking eventually.
- t.Fatalf("socket never blocked!")
+ t.Fatalf("Socket never blocked!")
}
}
@@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) {
// Expected to block immediately.
_, err := r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
// Put some data in the pipe.
w := server.Writer(false)
if n, err := w.WriteVec(b); n != 1 || err != nil {
- t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it not to block.
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err)
}
// Expect it to return a block error again.
r = client.Reader(false)
_, err = r.ReadVec(b)
if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN {
- t.Fatalf("read didn't block, got err %v expected blocking err", err)
+ t.Fatalf("Read didn't block, got err %v expected blocking err", err)
}
}
@@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c'}, {'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[1][0] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0])
}
}
@@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) {
// Write on the client.
w := client.Writer(true)
if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil {
- t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err)
}
// Read on the server.
b := [][]byte{{'c', 'c'}}
r := server.Reader(true)
if n, err := r.ReadVec(b); n != 2 || err != nil {
- t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err)
+ t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err)
}
if b[0][0] != 'a' || b[0][1] != 'b' {
- t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
+ t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1])
}
}
@@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) {
w := server.Writer(true)
w.PackFDs(0, 1, 2)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
// Read on the client, without enabling FDs.
b := [][]byte{{'b'}}
r := client.Reader(true)
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Make sure the FDs are not received.
fds, err := r.ExtractFDs()
if len(fds) != 0 || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err)
}
}
@@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) {
w := s.Writer(true)
w.PackFDs(fds...)
if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil {
- t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err)
}
}
@@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Count the number of FDs.
preEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
// Read on the client.
@@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
r.EnableFDs(enableSize)
}
if n, err := r.ReadVec(b); n != 1 || err != nil {
- t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err)
+ t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err)
}
if b[0][0] != 'a' {
- t.Fatalf("got bad read data, got %c, expected a", b[0][0])
+ t.Fatalf("Got bad read data, got %c, expected a", b[0][0])
}
// Count the new number of FDs.
postEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(preEntries)+expected != len(postEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries))
}
// Make sure the FDs are there.
fds, err := r.ExtractFDs()
if len(fds) != expected || err != nil {
- t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
+ t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected)
}
// Make sure they are different from the originals.
for i := 0; i < len(fds); i++ {
if fds[i] == origFDs[i] {
- t.Errorf("got original fd for index %d, expected different", i)
+ t.Errorf("Got original fd for index %d, expected different", i)
}
}
@@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) {
// Make sure the count is back to normal.
finalEntries, err := ioutil.ReadDir("/proc/self/fd")
if err != nil {
- t.Fatalf("can't readdir, got err %v expected nil", err)
+ t.Fatalf("Can't readdir, got err %v expected nil", err)
}
if len(finalEntries) != len(preEntries) {
- t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
+ t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries))
}
}
@@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) {
}
if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
+ t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil)
}
}
@@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) {
want := "bad file descriptor"
if _, err := s.GetPeerCred(); err == nil || err.Error() != want {
- t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want)
+ t.Errorf("s.GetPeerCred() = %v, want = %s", err, want)
}
}
func TestAcceptClosed(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Fatalf("Close failed, got err %v expected nil", err)
}
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
}
func TestCloseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
if err := ss.Close(); err != nil {
- t.Fatalf("close failed, got err %v expected nil", err)
+ t.Errorf("Close failed, got err %v expected nil", err)
}
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
@@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) {
func TestReleaseAfterAcceptStart(t *testing.T) {
name, err := randomFilename()
if err != nil {
- t.Fatalf("unable to generate file, got err %v expected nil", err)
+ t.Fatalf("Unable to generate file, got err %v expected nil", err)
}
ss, err := BindAndListen(name, false)
if err != nil {
- t.Fatalf("bind failed, got err %v expected nil", err)
+ t.Fatalf("Bind failed, got err %v expected nil", err)
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
+ defer wg.Done()
time.Sleep(50 * time.Millisecond)
fd, err := ss.Release()
if err != nil {
- t.Fatalf("Release failed, got err %v expected nil", err)
+ t.Errorf("Release failed, got err %v expected nil", err)
}
syscall.Close(fd)
- wg.Done()
}()
if _, err := ss.Accept(); err == nil {
- t.Errorf("accept on closed SocketServer, got err %v, want != nil", err)
+ t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err)
}
wg.Wait()
@@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) {
cm.PackFDs(want...)
got, err := cm.ExtractFDs()
if err != nil || !reflect.DeepEqual(got, want) {
- t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
+ t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil)
}
}
}
@@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) {
for i := 0; i < b.N; i++ {
n, err := server.Read(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
n, err = server.Write(buf)
if n != 1 || err != nil {
- b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err)
+ return
}
}
}()
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 67a950444..08519d986 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
//
// +stateify savable
type Queue struct {
- list waiterList `state:"zerovalue"`
+ list waiterList
mu sync.RWMutex `state:"nosave"`
}
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 248f77c34..8c73dc5dc 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -38,6 +38,7 @@ go_library(
"//pkg/memutil",
"//pkg/rand",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/control",
@@ -74,6 +75,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/sighandling",
"//pkg/sentry/socket/hostinet",
+ "//pkg/sentry/socket/netfilter",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netlink/route",
"//pkg/sentry/socket/netlink/uevent",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 894651519..fdf13c8e1 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/state"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
@@ -195,7 +196,7 @@ type containerManager struct {
// StartRoot will start the root container process.
func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error {
- log.Debugf("containerManager.StartRoot %q", *cid)
+ log.Debugf("containerManager.StartRoot, cid: %s", *cid)
// Tell the root container to start and wait for the result.
cm.startChan <- struct{}{}
if err := <-cm.startResultChan; err != nil {
@@ -206,13 +207,13 @@ func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error {
// Processes retrieves information about processes running in the sandbox.
func (cm *containerManager) Processes(cid *string, out *[]*control.Process) error {
- log.Debugf("containerManager.Processes: %q", *cid)
+ log.Debugf("containerManager.Processes, cid: %s", *cid)
return control.Processes(cm.l.k, *cid, out)
}
// Create creates a container within a sandbox.
func (cm *containerManager) Create(cid *string, _ *struct{}) error {
- log.Debugf("containerManager.Create: %q", *cid)
+ log.Debugf("containerManager.Create, cid: %s", *cid)
return cm.l.createContainer(*cid)
}
@@ -236,12 +237,11 @@ type StartArgs struct {
// Start runs a created container within a sandbox.
func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
- log.Debugf("containerManager.Start: %+v", args)
-
// Validate arguments.
if args == nil {
return errors.New("start missing arguments")
}
+ log.Debugf("containerManager.Start, cid: %s, args: %+v", args.CID, args)
if args.Spec == nil {
return errors.New("start arguments missing spec")
}
@@ -268,27 +268,27 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error {
}
}()
if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil {
- log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err)
+ log.Debugf("containerManager.Start failed, cid: %s, args: %+v, err: %v", args.CID, args, err)
return err
}
- log.Debugf("Container %q started", args.CID)
+ log.Debugf("Container started, cid: %s", args.CID)
return nil
}
// Destroy stops a container if it is still running and cleans up its
// filesystem.
func (cm *containerManager) Destroy(cid *string, _ *struct{}) error {
- log.Debugf("containerManager.destroy %q", *cid)
+ log.Debugf("containerManager.destroy, cid: %s", *cid)
return cm.l.destroyContainer(*cid)
}
// ExecuteAsync starts running a command on a created or running sandbox. It
// returns the PID of the new process.
func (cm *containerManager) ExecuteAsync(args *control.ExecArgs, pid *int32) error {
- log.Debugf("containerManager.ExecuteAsync: %+v", args)
+ log.Debugf("containerManager.ExecuteAsync, cid: %s, args: %+v", args.ContainerID, args)
tgid, err := cm.l.executeAsync(args)
if err != nil {
- log.Debugf("containerManager.ExecuteAsync failed: %+v: %v", args, err)
+ log.Debugf("containerManager.ExecuteAsync failed, cid: %s, args: %+v, err: %v", args.ContainerID, args, err)
return err
}
*pid = int32(tgid)
@@ -367,12 +367,20 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
cm.l.k = k
// Set up the restore environment.
+ ctx := k.SupervisorContext()
mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints)
- renv, err := mntr.createRestoreEnvironment(cm.l.root.conf)
- if err != nil {
- return fmt.Errorf("creating RestoreEnvironment: %v", err)
+ if kernel.VFS2Enabled {
+ ctx, err = mntr.configureRestore(ctx, cm.l.root.conf)
+ if err != nil {
+ return fmt.Errorf("configuring filesystem restore: %v", err)
+ }
+ } else {
+ renv, err := mntr.createRestoreEnvironment(cm.l.root.conf)
+ if err != nil {
+ return fmt.Errorf("creating RestoreEnvironment: %v", err)
+ }
+ fs.SetRestoreEnvironment(*renv)
}
- fs.SetRestoreEnvironment(*renv)
// Prepare to load from the state file.
if eps, ok := networkStack.(*netstack.Stack); ok {
@@ -399,7 +407,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Load the state.
loadOpts := state.LoadOpts{Source: specFile}
- if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil {
+ if err := loadOpts.Load(ctx, k, networkStack, time.NewCalibratedClocks(), &vfs.CompleteRestoreOptions{}); err != nil {
return err
}
@@ -444,9 +452,9 @@ func (cm *containerManager) Resume(_, _ *struct{}) error {
// Wait waits for the init process in the given container.
func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error {
- log.Debugf("containerManager.Wait")
+ log.Debugf("containerManager.Wait, cid: %s", *cid)
err := cm.l.waitContainer(*cid, waitStatus)
- log.Debugf("containerManager.Wait returned, waitStatus: %v: %v", waitStatus, err)
+ log.Debugf("containerManager.Wait returned, cid: %s, waitStatus: %#x, err: %v", *cid, *waitStatus, err)
return err
}
@@ -461,8 +469,10 @@ type WaitPIDArgs struct {
// WaitPID waits for the process with PID 'pid' in the sandbox.
func (cm *containerManager) WaitPID(args *WaitPIDArgs, waitStatus *uint32) error {
- log.Debugf("containerManager.Wait")
- return cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus)
+ log.Debugf("containerManager.Wait, cid: %s, pid: %d", args.CID, args.PID)
+ err := cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus)
+ log.Debugf("containerManager.Wait, cid: %s, pid: %d, waitStatus: %#x, err: %v", args.CID, args.PID, *waitStatus, err)
+ return err
}
// SignalDeliveryMode enumerates different signal delivery modes.
@@ -519,6 +529,6 @@ type SignalArgs struct {
// indicated process, to all processes in the container, or to the foreground
// process group.
func (cm *containerManager) Signal(args *SignalArgs, _ *struct{}) error {
- log.Debugf("containerManager.Signal %+v", args)
+ log.Debugf("containerManager.Signal: cid: %s, PID: %d, signal: %d, mode: %v", args.CID, args.PID, args.Signo, args.Mode)
return cm.l.signal(args.CID, args.PID, args.Signo, args.Mode)
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index ddf288456..6b6ae98d7 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -105,33 +105,28 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name
// mandatory mounts that are required by the OCI specification.
func compileMounts(spec *specs.Spec) []specs.Mount {
// Keep track of whether proc and sys were mounted.
- var procMounted, sysMounted bool
+ var procMounted, sysMounted, devMounted, devptsMounted bool
var mounts []specs.Mount
- // Always mount /dev.
- mounts = append(mounts, specs.Mount{
- Type: devtmpfs.Name,
- Destination: "/dev",
- })
-
- mounts = append(mounts, specs.Mount{
- Type: devpts.Name,
- Destination: "/dev/pts",
- })
-
// Mount all submounts from the spec.
for _, m := range spec.Mounts {
if !specutils.IsSupportedDevMount(m) {
log.Warningf("ignoring dev mount at %q", m.Destination)
continue
}
- mounts = append(mounts, m)
switch filepath.Clean(m.Destination) {
case "/proc":
procMounted = true
case "/sys":
sysMounted = true
+ case "/dev":
+ m.Type = devtmpfs.Name
+ devMounted = true
+ case "/dev/pts":
+ m.Type = devpts.Name
+ devptsMounted = true
}
+ mounts = append(mounts, m)
}
// Mount proc and sys even if the user did not ask for it, as the spec
@@ -149,6 +144,18 @@ func compileMounts(spec *specs.Spec) []specs.Mount {
Destination: "/sys",
})
}
+ if !devMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: devtmpfs.Name,
+ Destination: "/dev",
+ })
+ }
+ if !devptsMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: devpts.Name,
+ Destination: "/dev/pts",
+ })
+ }
// The mandatory mounts should be ordered right after the root, in case
// there are submounts of these mandatory mounts already in the spec.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 8ad000497..ebdd518d0 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fdimport"
@@ -49,6 +50,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/sighandling"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2"
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
@@ -476,6 +478,10 @@ func (l *Loader) Destroy() {
// save/restore.
l.k.Release()
+ // All sentry-created resources should have been released at this point;
+ // check for reference leaks.
+ refsvfs2.DoLeakCheck()
+
// In the success case, stdioFDs and goferFDs will only contain
// released/closed FDs that ownership has been passed over to host FDs and
// gofer sessions. Close them here in case of failure.
@@ -737,7 +743,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
return nil, err
}
- // Add the HOME enviroment variable if it is not already set.
+ // Add the HOME environment variable if it is not already set.
var envv []string
if kernel.VFS2Enabled {
envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2,
@@ -882,7 +888,7 @@ func (l *Loader) destroyContainer(cid string) error {
}
}
- log.Debugf("Container destroyed %q", cid)
+ log.Debugf("Container destroyed, cid: %s", cid)
return nil
}
@@ -1079,6 +1085,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in
// privileges.
RawFactory: raw.EndpointFactory{},
UniqueID: uniqueID,
+ IPTables: netfilter.DefaultLinuxTables(),
})}
// Enable SACK Recovery.
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index e376f944b..b77b4762e 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -266,7 +266,7 @@ type CreateMountTestcase struct {
func createMountTestcases() []*CreateMountTestcase {
testCases := []*CreateMountTestcase{
- &CreateMountTestcase{
+ {
// Only proc.
name: "only proc mount",
spec: specs.Spec{
@@ -304,11 +304,10 @@ func createMountTestcases() []*CreateMountTestcase {
},
},
},
- // /some/deep/path should be mounted, along with /proc,
- // /dev, and /sys.
+ // /some/deep/path should be mounted, along with /proc, /dev, and /sys.
expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- &CreateMountTestcase{
+ {
// Mounts are nested inside each other.
name: "nested mounts",
spec: specs.Spec{
@@ -352,7 +351,7 @@ func createMountTestcases() []*CreateMountTestcase {
expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux",
"/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- &CreateMountTestcase{
+ {
name: "mount inside /dev",
spec: specs.Spec{
Root: &specs.Root{
@@ -395,35 +394,37 @@ func createMountTestcases() []*CreateMountTestcase {
},
expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"},
},
- }
-
- vfsCase := &CreateMountTestcase{
- name: "mounts inside mandatory mounts",
- spec: specs.Spec{
- Root: &specs.Root{
- Path: os.TempDir(),
- Readonly: true,
- },
- Mounts: []specs.Mount{
- {
- Destination: "/proc",
- Type: "tmpfs",
- },
- {
- Destination: "/sys/bar",
- Type: "tmpfs",
+ {
+ name: "mounts inside mandatory mounts",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
},
-
- {
- Destination: "/tmp/baz",
- Type: "tmpfs",
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/sys/bar",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/tmp/baz",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/dev/goo",
+ Type: "tmpfs",
+ },
},
},
+ expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz", "/dev/goo"},
},
- expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"},
}
- return append(testCases, vfsCase)
+ return testCases
}
// Test that MountNamespace can be created with various specs.
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 004da5b40..b157387ef 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -210,6 +210,9 @@ func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *c
ReadOnly: c.root.Readonly,
GetFilesystemOptions: vfs.GetFilesystemOptions{
Data: strings.Join(data, ","),
+ InternalData: gofer.InternalFilesystemOptions{
+ UniqueID: "/",
+ },
},
InternalMount: true,
}
@@ -427,6 +430,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
fsName := m.Type
useOverlay := false
var data []string
+ var iopts interface{}
// Find filesystem name and FS specific data field.
switch m.Type {
@@ -451,6 +455,9 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
return "", nil, false, fmt.Errorf("9P mount requires a connection FD")
}
data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+ iopts = gofer.InternalFilesystemOptions{
+ UniqueID: m.Destination,
+ }
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
@@ -462,7 +469,8 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
opts := &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
- Data: strings.Join(data, ","),
+ Data: strings.Join(data, ","),
+ InternalData: iopts,
},
InternalMount: true,
}
@@ -667,3 +675,21 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede
}
return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds)
}
+
+// configureRestore returns an updated context.Context including filesystem
+// state used by restore defined by conf.
+func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) {
+ fdmap := make(map[string]int)
+ fdmap["/"] = c.fds.remove()
+ mounts, err := c.prepareMountsVFS2()
+ if err != nil {
+ return ctx, err
+ }
+ for i := range c.mounts {
+ submount := &mounts[i]
+ if submount.fd >= 0 {
+ fdmap[submount.Destination] = submount.fd
+ }
+ }
+ return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil
+}
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index 56da21584..5bd0afc52 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "io"
"io/ioutil"
"os"
"path/filepath"
@@ -198,8 +199,13 @@ func LoadPaths(pid string) (map[string]string, error) {
}
defer f.Close()
+ return loadPathsHelper(f)
+}
+
+func loadPathsHelper(cgroup io.Reader) (map[string]string, error) {
paths := make(map[string]string)
- scanner := bufio.NewScanner(f)
+
+ scanner := bufio.NewScanner(cgroup)
for scanner.Scan() {
// Format: ID:[name=]controller1,controller2:path
// Example: 2:cpu,cpuacct:/user.slice
@@ -207,6 +213,9 @@ func LoadPaths(pid string) (map[string]string, error) {
if len(tokens) != 3 {
return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text())
}
+ if len(tokens[1]) == 0 {
+ continue
+ }
for _, ctrlr := range strings.Split(tokens[1], ",") {
// Remove prefix for cgroups with no controller, eg. systemd.
ctrlr = strings.TrimPrefix(ctrlr, "name=")
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 4db5ee5c3..9794517a7 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -647,3 +647,83 @@ func TestPids(t *testing.T) {
})
}
}
+
+func TestLoadPaths(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ cgroups string
+ want map[string]string
+ err string
+ }{
+ {
+ name: "abs-path",
+ cgroups: "0:ctr:/path",
+ want: map[string]string{"ctr": "/path"},
+ },
+ {
+ name: "rel-path",
+ cgroups: "0:ctr:rel-path",
+ want: map[string]string{"ctr": "rel-path"},
+ },
+ {
+ name: "non-controller",
+ cgroups: "0:name=systemd:/path",
+ want: map[string]string{"systemd": "/path"},
+ },
+ {
+ name: "empty",
+ },
+ {
+ name: "multiple",
+ cgroups: "0:ctr0:/path0\n" +
+ "1:ctr1:/path1\n" +
+ "2::/empty\n",
+ want: map[string]string{
+ "ctr0": "/path0",
+ "ctr1": "/path1",
+ },
+ },
+ {
+ name: "missing-field",
+ cgroups: "0:nopath\n",
+ err: "invalid cgroups file",
+ },
+ {
+ name: "too-many-fields",
+ cgroups: "0:ctr:/path:extra\n",
+ err: "invalid cgroups file",
+ },
+ {
+ name: "multiple-malformed",
+ cgroups: "0:ctr0:/path0\n" +
+ "1:ctr1:/path1\n" +
+ "2:\n",
+ err: "invalid cgroups file",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ r := strings.NewReader(tc.cgroups)
+ got, err := loadPathsHelper(r)
+ if len(tc.err) == 0 {
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ } else if !strings.Contains(err.Error(), tc.err) {
+ t.Fatalf("Wrong error message, want: *%s*, got: %v", tc.err, err)
+ }
+ for key, vWant := range tc.want {
+ vGot, ok := got[key]
+ if !ok {
+ t.Errorf("Missing controller %q", key)
+ }
+ if vWant != vGot {
+ t.Errorf("Wrong controller %q value, want: %q, got: %q", key, vWant, vGot)
+ }
+ delete(got, key)
+ }
+ for k, v := range got {
+ t.Errorf("Unexpected controller %q: %q", k, v)
+ }
+ })
+ }
+}
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index cd419e1aa..2c92e3067 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -131,11 +131,11 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitUsageError
}
- // Ensure that if there is a panic, all goroutine stacks are printed.
- debug.SetTraceback("system")
-
conf := args[0].(*config.Config)
+ // Set traceback level
+ debug.SetTraceback(conf.Traceback)
+
if b.attached {
// Ensure this process is killed after parent process terminates when
// attached mode is enabled. In the unfortunate event that the parent
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index 8fe0c427a..c0bc8f064 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -75,7 +75,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa
conf := args[0].(*config.Config)
waitStatus := args[1].(*syscall.WaitStatus)
- cont, err := container.Load(conf.RootDir, id)
+ cont, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
@@ -149,6 +149,9 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa
}
ws, err := cont.Wait()
+ if err != nil {
+ Fatalf("Error waiting for container: %v", err)
+ }
*waitStatus = ws
return subcommands.ExitSuccess
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index 132198222..609e8231c 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -91,7 +91,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitUsageError
}
var err error
- c, err = container.Load(conf.RootDir, f.Arg(0))
+ c, err = container.LoadAndCheck(conf.RootDir, f.Arg(0))
if err != nil {
return Errorf("loading container %q: %v", f.Arg(0), err)
}
@@ -106,7 +106,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return Errorf("listing containers: %v", err)
}
for _, id := range ids {
- candidate, err := container.Load(conf.RootDir, id)
+ candidate, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
return Errorf("loading container %q: %v", id, err)
}
diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go
index 4e49deff8..a25637265 100644
--- a/runsc/cmd/delete.go
+++ b/runsc/cmd/delete.go
@@ -68,7 +68,7 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}
func (d *Delete) execute(ids []string, conf *config.Config) error {
for _, id := range ids {
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
if os.IsNotExist(err) && d.force {
log.Warningf("couldn't find container %q: %v", id, err)
diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go
index 25fe2cf1c..3836b7b4e 100644
--- a/runsc/cmd/events.go
+++ b/runsc/cmd/events.go
@@ -74,7 +74,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading sandbox: %v", err)
}
@@ -85,7 +85,12 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa
ev, err := c.Event()
if err != nil {
log.Warningf("Error getting events for container: %v", err)
+ if evs.stats {
+ return subcommands.ExitFailure
+ }
}
+ log.Debugf("Events: %+v", ev)
+
// err must be preserved because it is used below when breaking
// out of the loop.
b, err := json.Marshal(ev)
@@ -101,11 +106,9 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa
if err != nil {
return subcommands.ExitFailure
}
- break
+ return subcommands.ExitSuccess
}
time.Sleep(time.Duration(evs.intervalSec) * time.Second)
}
-
- return subcommands.ExitSuccess
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index 775ed4b43..86c02a22a 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -112,7 +112,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
waitStatus := args[1].(*syscall.WaitStatus)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading sandbox: %v", err)
}
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
index 04eee99b2..fe69e2a08 100644
--- a/runsc/cmd/kill.go
+++ b/runsc/cmd/kill.go
@@ -69,7 +69,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("it is invalid to specify both --all and --pid")
}
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go
index f92d6fef9..6907eb16a 100644
--- a/runsc/cmd/list.go
+++ b/runsc/cmd/list.go
@@ -79,7 +79,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Collect the containers.
var containers []*container.Container
for _, id := range ids {
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container %q: %v", id, err)
}
diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go
index 0eb1402ed..fe7d4e257 100644
--- a/runsc/cmd/pause.go
+++ b/runsc/cmd/pause.go
@@ -55,7 +55,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- cont, err := container.Load(conf.RootDir, id)
+ cont, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go
index bc58c928f..18d7a1436 100644
--- a/runsc/cmd/ps.go
+++ b/runsc/cmd/ps.go
@@ -60,7 +60,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{})
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading sandbox: %v", err)
}
diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go
index f24823f99..a00928204 100644
--- a/runsc/cmd/resume.go
+++ b/runsc/cmd/resume.go
@@ -56,7 +56,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}
id := f.Arg(0)
conf := args[0].(*config.Config)
- cont, err := container.Load(conf.RootDir, id)
+ cont, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index 88991b521..f6499cc44 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
)
// Start implements subcommands.Command for the "start" command.
@@ -54,10 +55,16 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
+ // Read the spec again here to ensure flag annotations from the spec are
+ // applied to "conf".
+ if _, err := specutils.ReadSpec(c.BundleDir, conf); err != nil {
+ Fatalf("reading spec: %v", err)
+ }
+
if err := c.Start(conf); err != nil {
Fatalf("starting container: %v", err)
}
diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go
index 2bd2ab9f8..d8a70dd7f 100644
--- a/runsc/cmd/state.go
+++ b/runsc/cmd/state.go
@@ -57,7 +57,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
index 28d0642ed..c1d6aeae2 100644
--- a/runsc/cmd/wait.go
+++ b/runsc/cmd/wait.go
@@ -72,7 +72,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
id := f.Arg(0)
conf := args[0].(*config.Config)
- c, err := container.Load(conf.RootDir, id)
+ c, err := container.LoadAndCheck(conf.RootDir, id)
if err != nil {
Fatalf("loading container: %v", err)
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index f30f79f68..b02d8e2e1 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -37,6 +37,9 @@ type Config struct {
// RootDir is the runtime root directory.
RootDir string `flag:"root"`
+ // Traceback changes the Go runtime's traceback level.
+ Traceback string `flag:"traceback"`
+
// Debug indicates that debug logging should be enabled.
Debug bool `flag:"debug"`
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index a5f25cfa2..13d8f1b25 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -29,7 +29,7 @@ import (
var registration sync.Once
-// This is the set of flags used to populate Config.
+// RegisterFlags registers flags used to populate Config.
func RegisterFlags() {
registration.Do(func() {
// Although these flags are not part of the OCI spec, they are used by
@@ -49,6 +49,7 @@ func RegisterFlags() {
flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
flag.Bool("alsologtostderr", false, "send log messages to stderr.")
flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.")
+ flag.String("traceback", "system", "golang runtime's traceback level")
// Debugging flags: strace related
flag.Bool("strace", false, "enable strace.")
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 63f64ce6e..4aa139c88 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -159,9 +159,9 @@ func loadSandbox(rootDir, id string) ([]*Container, error) {
// container to which id unambiguously refers to. Returns ErrNotExist if
// container doesn't exist.
func Load(rootDir, partialID string) (*Container, error) {
- log.Debugf("Load container %q %q", rootDir, partialID)
+ log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID)
if err := validateID(partialID); err != nil {
- return nil, fmt.Errorf("validating id: %v", err)
+ return nil, fmt.Errorf("invalid container id: %v", err)
}
id, err := findContainerID(rootDir, partialID)
@@ -184,22 +184,31 @@ func Load(rootDir, partialID string) (*Container, error) {
}
return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
}
+ return c, nil
+}
+
+// LoadAndCheck is similar to Load(), but also checks if the container is still
+// running to get an error earlier to the caller.
+func LoadAndCheck(rootDir, partialID string) (*Container, error) {
+ c, err := Load(rootDir, partialID)
+ if err != nil {
+ // Preserve error so that callers can distinguish 'not found' errors.
+ return nil, err
+ }
- // If the status is "Running" or "Created", check that the sandbox
- // process still exists, and set it to Stopped if it does not.
+ // If the status is "Running" or "Created", check that the sandbox/container
+ // is still running, setting it to Stopped if not.
//
// This is inherently racy.
- if c.Status == Running || c.Status == Created {
- // Check if the sandbox process is still running.
+ switch c.Status {
+ case Created:
if !c.isSandboxRunning() {
// Sandbox no longer exists, so this container definitely does not exist.
c.changeStatus(Stopped)
- } else if c.Status == Running {
- // Container state should reflect the actual state of the application, so
- // we don't consider gofer process here.
- if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
- c.changeStatus(Stopped)
- }
+ }
+ case Running:
+ if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
+ c.changeStatus(Stopped)
}
}
@@ -271,7 +280,7 @@ type Args struct {
// indicates that an existing Sandbox should be used. The caller must call
// Destroy() on the container.
func New(conf *config.Config, args Args) (*Container, error) {
- log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir)
+ log.Debugf("Create container, cid: %s, rootDir: %q", args.ID, conf.RootDir)
if err := validateID(args.ID); err != nil {
return nil, err
}
@@ -310,7 +319,15 @@ func New(conf *config.Config, args Args) (*Container, error) {
// indicate the ID of the sandbox, which is the same as the ID of the
// init container in the sandbox.
if isRoot(args.Spec) {
- log.Debugf("Creating new sandbox for container %q", args.ID)
+ log.Debugf("Creating new sandbox for container, cid: %s", args.ID)
+
+ if args.Spec.Linux == nil {
+ args.Spec.Linux = &specs.Linux{}
+ }
+ // Don't force the use of cgroups in tests because they lack permission to do so.
+ if args.Spec.Linux.CgroupsPath == "" && !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ args.Spec.Linux.CgroupsPath = "/" + args.ID
+ }
// Create and join cgroup before processes are created to ensure they are
// part of the cgroup from the start (and all their children processes).
@@ -321,7 +338,13 @@ func New(conf *config.Config, args Args) (*Container, error) {
if cg != nil {
// If there is cgroup config, install it before creating sandbox process.
if err := cg.Install(args.Spec.Linux.Resources); err != nil {
- return nil, fmt.Errorf("configuring cgroup: %v", err)
+ switch {
+ case errors.Is(err, syscall.EACCES) && conf.Rootless:
+ log.Warningf("Skipping cgroup configuration in rootless mode: %v", err)
+ cg = nil
+ default:
+ return nil, fmt.Errorf("configuring cgroup: %v", err)
+ }
}
}
if err := runInCgroup(cg, func() error {
@@ -366,10 +389,10 @@ func New(conf *config.Config, args Args) (*Container, error) {
if !ok {
return nil, fmt.Errorf("no sandbox ID found when creating container")
}
- log.Debugf("Creating new container %q in sandbox %q", c.ID, sbid)
+ log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sbid)
// Find the sandbox associated with this ID.
- sb, err := Load(conf.RootDir, sbid)
+ sb, err := LoadAndCheck(conf.RootDir, sbid)
if err != nil {
return nil, err
}
@@ -399,7 +422,7 @@ func New(conf *config.Config, args Args) (*Container, error) {
// Start starts running the containerized process inside the sandbox.
func (c *Container) Start(conf *config.Config) error {
- log.Debugf("Start container %q", c.ID)
+ log.Debugf("Start container, cid: %s", c.ID)
if err := c.Saver.lock(); err != nil {
return err
@@ -462,7 +485,7 @@ func (c *Container) Start(conf *config.Config) error {
unlock.Clean()
// Adjust the oom_score_adj for sandbox. This must be done after saveLocked().
- if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil {
+ if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Spec, c.Saver.RootDir, false); err != nil {
return err
}
@@ -474,7 +497,7 @@ func (c *Container) Start(conf *config.Config) error {
// Restore takes a container and replaces its kernel and file system
// to restore a container from its state file.
func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile string) error {
- log.Debugf("Restore container %q", c.ID)
+ log.Debugf("Restore container, cid: %s", c.ID)
if err := c.Saver.lock(); err != nil {
return err
}
@@ -501,7 +524,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile s
// Run is a helper that calls Create + Start + Wait.
func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) {
- log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir)
+ log.Debugf("Run container, cid: %s, rootDir: %q", args.ID, conf.RootDir)
c, err := New(conf, args)
if err != nil {
return 0, fmt.Errorf("creating container: %v", err)
@@ -533,7 +556,7 @@ func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) {
// Execute runs the specified command in the container. It returns the PID of
// the newly created process.
func (c *Container) Execute(args *control.ExecArgs) (int32, error) {
- log.Debugf("Execute in container %q, args: %+v", c.ID, args)
+ log.Debugf("Execute in container, cid: %s, args: %+v", c.ID, args)
if err := c.requireStatus("execute in", Created, Running); err != nil {
return 0, err
}
@@ -543,7 +566,7 @@ func (c *Container) Execute(args *control.ExecArgs) (int32, error) {
// Event returns events for the container.
func (c *Container) Event() (*boot.Event, error) {
- log.Debugf("Getting events for container %q", c.ID)
+ log.Debugf("Getting events for container, cid: %s", c.ID)
if err := c.requireStatus("get events for", Created, Running, Paused); err != nil {
return nil, err
}
@@ -563,14 +586,19 @@ func (c *Container) SandboxPid() int {
// Call to wait on a stopped container is needed to retrieve the exit status
// and wait returns immediately.
func (c *Container) Wait() (syscall.WaitStatus, error) {
- log.Debugf("Wait on container %q", c.ID)
- return c.Sandbox.Wait(c.ID)
+ log.Debugf("Wait on container, cid: %s", c.ID)
+ ws, err := c.Sandbox.Wait(c.ID)
+ if err == nil {
+ // Wait succeeded, container is not running anymore.
+ c.changeStatus(Stopped)
+ }
+ return ws, err
}
// WaitRootPID waits for process 'pid' in the sandbox's PID namespace and
// returns its WaitStatus.
func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
- log.Debugf("Wait on PID %d in sandbox %q", pid, c.Sandbox.ID)
+ log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID)
if !c.isSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
}
@@ -580,7 +608,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
// WaitPID waits for process 'pid' in the container's PID namespace and returns
// its WaitStatus.
func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
- log.Debugf("Wait on PID %d in container %q", pid, c.ID)
+ log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID)
if !c.isSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
}
@@ -592,7 +620,7 @@ func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
// SignalContainer returns an error if the container is already stopped.
// TODO(b/113680494): Distinguish different error types.
func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
- log.Debugf("Signal container %q: %v", c.ID, sig)
+ log.Debugf("Signal container, cid: %s, signal: %v (%d)", c.ID, sig, sig)
// Signaling container in Stopped state is allowed. When all=false,
// an error will be returned anyway; when all=true, this allows
// sending signal to other processes inside the container even
@@ -609,7 +637,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
// SignalProcess sends sig to a specific process in the container.
func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
- log.Debugf("Signal process %d in container %q: %v", pid, c.ID, sig)
+ log.Debugf("Signal process %d in container, cid: %s, signal: %v (%d)", pid, c.ID, sig, sig)
if err := c.requireStatus("signal a process inside", Running); err != nil {
return err
}
@@ -623,15 +651,15 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
// container process inside the sandbox. It returns a function that will stop
// forwarding signals.
func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
- log.Debugf("Forwarding all signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+ log.Debugf("Forwarding all signals to container, cid: %s, PIDPID: %d, fgProcess: %t", c.ID, pid, fgProcess)
stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
- log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", sig, c.ID, pid, fgProcess)
+ log.Debugf("Forwarding signal %d to container, cid: %s, PID: %d, fgProcess: %t", sig, c.ID, pid, fgProcess)
if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil {
log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err)
}
})
return func() {
- log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+ log.Debugf("Done forwarding signals to container, cid: %s, PID: %d, fgProcess: %t", c.ID, pid, fgProcess)
stop()
}
}
@@ -639,7 +667,7 @@ func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
// Checkpoint sends the checkpoint call to the container.
// The statefile will be written to f, the file at the specified image-path.
func (c *Container) Checkpoint(f *os.File) error {
- log.Debugf("Checkpoint container %q", c.ID)
+ log.Debugf("Checkpoint container, cid: %s", c.ID)
if err := c.requireStatus("checkpoint", Created, Running, Paused); err != nil {
return err
}
@@ -649,7 +677,7 @@ func (c *Container) Checkpoint(f *os.File) error {
// Pause suspends the container and its kernel.
// The call only succeeds if the container's status is created or running.
func (c *Container) Pause() error {
- log.Debugf("Pausing container %q", c.ID)
+ log.Debugf("Pausing container, cid: %s", c.ID)
if err := c.Saver.lock(); err != nil {
return err
}
@@ -660,7 +688,7 @@ func (c *Container) Pause() error {
}
if err := c.Sandbox.Pause(c.ID); err != nil {
- return fmt.Errorf("pausing container: %v", err)
+ return fmt.Errorf("pausing container %q: %v", c.ID, err)
}
c.changeStatus(Paused)
return c.saveLocked()
@@ -669,7 +697,7 @@ func (c *Container) Pause() error {
// Resume unpauses the container and its kernel.
// The call only succeeds if the container's status is paused.
func (c *Container) Resume() error {
- log.Debugf("Resuming container %q", c.ID)
+ log.Debugf("Resuming container, cid: %s", c.ID)
if err := c.Saver.lock(); err != nil {
return err
}
@@ -708,7 +736,7 @@ func (c *Container) Processes() ([]*control.Process, error) {
// Destroy stops all processes and frees all resources associated with the
// container.
func (c *Container) Destroy() error {
- log.Debugf("Destroy container %q", c.ID)
+ log.Debugf("Destroy container, cid: %s", c.ID)
if err := c.Saver.lock(); err != nil {
return err
@@ -745,14 +773,12 @@ func (c *Container) Destroy() error {
c.changeStatus(Stopped)
// Adjust oom_score_adj for the sandbox. This must be done after the container
- // is stopped and the directory at c.Root is removed. Adjustment can be
- // skipped if the root container is exiting, because it brings down the entire
- // sandbox.
+ // is stopped and the directory at c.Root is removed.
//
// Use 'sb' to tell whether it has been executed before because Destroy must
// be idempotent.
- if sb != nil && !isRoot(c.Spec) {
- if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil {
+ if sb != nil {
+ if err := adjustSandboxOOMScoreAdj(sb, c.Spec, c.Saver.RootDir, true); err != nil {
errs = append(errs, err.Error())
}
}
@@ -781,7 +807,7 @@ func (c *Container) Destroy() error {
//
// Precondition: container must be locked with container.lock().
func (c *Container) saveLocked() error {
- log.Debugf("Save container %q", c.ID)
+ log.Debugf("Save container, cid: %s", c.ID)
if err := c.Saver.saveLocked(c); err != nil {
return fmt.Errorf("saving container metadata: %v", err)
}
@@ -795,7 +821,7 @@ func (c *Container) stop() error {
var cgroup *cgroup.Cgroup
if c.Sandbox != nil {
- log.Debugf("Destroying container %q", c.ID)
+ log.Debugf("Destroying container, cid: %s", c.ID)
if err := c.Sandbox.DestroyContainer(c.ID); err != nil {
return fmt.Errorf("destroying container %q: %v", c.ID, err)
}
@@ -809,7 +835,7 @@ func (c *Container) stop() error {
// Try killing gofer if it does not exit with container.
if c.GoferPid != 0 {
- log.Debugf("Killing gofer for container %q, PID: %d", c.ID, c.GoferPid)
+ log.Debugf("Killing gofer for container, cid: %s, PID: %d", c.ID, c.GoferPid)
if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
// The gofer may already be stopped, log the error.
log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err)
@@ -1082,7 +1108,13 @@ func (c *Container) adjustGoferOOMScoreAdj() error {
// TODO(gvisor.dev/issue/238): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
// sandbox. Use rpc client to synchronize.
-func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
+func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, spec *specs.Spec, rootDir string, destroy bool) error {
+ // Adjustment can be skipped if the root container is exiting, because it
+ // brings down the entire sandbox.
+ if isRoot(spec) && destroy {
+ return nil
+ }
+
containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
return fmt.Errorf("loading sandbox containers: %v", err)
@@ -1096,53 +1128,34 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
// Get the lowest score for all containers.
var lowScore int
scoreFound := false
- if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified {
- // This is a single-container sandbox. Set the oom_score_adj to
- // the value specified in the OCI bundle.
- if containers[0].Spec.Process.OOMScoreAdj != nil {
- scoreFound = true
- lowScore = *containers[0].Spec.Process.OOMScoreAdj
+ for _, container := range containers {
+ // Special multi-container support for CRI. Ignore the root container when
+ // calculating oom_score_adj for the sandbox because it is the
+ // infrastructure (pause) container and always has a very low oom_score_adj.
+ //
+ // We will use OOMScoreAdj in the single-container case where the
+ // containerd container-type annotation is not present.
+ if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox {
+ continue
}
- } else {
- for _, container := range containers {
- // Special multi-container support for CRI. Ignore the root
- // container when calculating oom_score_adj for the sandbox because
- // it is the infrastructure (pause) container and always has a very
- // low oom_score_adj.
- //
- // We will use OOMScoreAdj in the single-container case where the
- // containerd container-type annotation is not present.
- if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox {
- continue
- }
- if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
- scoreFound = true
- lowScore = *container.Spec.Process.OOMScoreAdj
- }
+ if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ scoreFound = true
+ lowScore = *container.Spec.Process.OOMScoreAdj
}
}
// If the container is destroyed and remaining containers have no
- // oomScoreAdj specified then we must revert to the oom_score_adj of the
- // parent process.
+ // oomScoreAdj specified then we must revert to the original oom_score_adj
+ // saved with the root container.
if !scoreFound && destroy {
- ppid, err := specutils.GetParentPid(s.Pid)
- if err != nil {
- return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err)
- }
- pScore, err := specutils.GetOOMScoreAdj(ppid)
- if err != nil {
- return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err)
- }
-
+ lowScore = containers[0].Sandbox.OriginalOOMScoreAdj
scoreFound = true
- lowScore = pScore
}
- // Only set oom_score_adj if one of the containers has oom_score_adj set
- // in the OCI bundle. If not, we need to inherit the parent process's
- // oom_score_adj.
+ // Only set oom_score_adj if one of the containers has oom_score_adj set. If
+ // not, oom_score_adj is inherited from the parent process.
+ //
// See: https://github.com/opencontainers/runtime-spec/blob/master/config.md#linux-process
if !scoreFound {
return nil
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index cc188f45b..fa99e403a 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -364,7 +364,7 @@ func TestLifecycle(t *testing.T) {
defer c.Destroy()
// Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
+ c, err = LoadAndCheck(rootDir, args.ID)
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -387,7 +387,7 @@ func TestLifecycle(t *testing.T) {
}
// Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
+ c, err = LoadAndCheck(rootDir, args.ID)
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -428,7 +428,7 @@ func TestLifecycle(t *testing.T) {
}
// Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
+ c, err = LoadAndCheck(rootDir, args.ID)
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -451,7 +451,7 @@ func TestLifecycle(t *testing.T) {
}
// Loading the container by id should fail.
- if _, err = Load(rootDir, args.ID); err == nil {
+ if _, err = LoadAndCheck(rootDir, args.ID); err == nil {
t.Errorf("expected loading destroyed container to fail, but it did not")
}
})
@@ -1738,7 +1738,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
cids[2]: cids[2],
}
for shortid, longid := range unambiguous {
- if _, err := Load(rootDir, shortid); err != nil {
+ if _, err := LoadAndCheck(rootDir, shortid); err != nil {
t.Errorf("%q should resolve to %q: %v", shortid, longid, err)
}
}
@@ -1749,7 +1749,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
"ba",
}
for _, shortid := range ambiguous {
- if s, err := Load(rootDir, shortid); err == nil {
+ if s, err := LoadAndCheck(rootDir, shortid); err == nil {
t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID)
}
}
@@ -1976,11 +1976,11 @@ func doDestroyNotStartedTest(t *testing.T, vfs2 bool) {
// TestDestroyStarting attempts to force a race between start and destroy.
func TestDestroyStarting(t *testing.T) {
- doDestroyNotStartedTest(t, false)
+ doDestroyStartingTest(t, false)
}
func TestDestroyStartedVFS2(t *testing.T) {
- doDestroyNotStartedTest(t, true)
+ doDestroyStartingTest(t, true)
}
func doDestroyStartingTest(t *testing.T, vfs2 bool) {
@@ -2007,7 +2007,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) {
// Container is not thread safe, so load another instance to run in
// concurrently.
- startCont, err := Load(rootDir, args.ID)
+ startCont, err := LoadAndCheck(rootDir, args.ID)
if err != nil {
t.Fatalf("error loading container: %v", err)
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 850e80290..cadc63bf3 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -15,6 +15,7 @@
package container
import (
+ "encoding/json"
"fmt"
"io/ioutil"
"math"
@@ -762,7 +763,7 @@ func TestMultiContainerKillAll(t *testing.T) {
// processes still running inside.
containers[1].SignalContainer(syscall.SIGKILL, false)
op := func() error {
- c, err := Load(conf.RootDir, ids[1])
+ c, err := LoadAndCheck(conf.RootDir, ids[1])
if err != nil {
return err
}
@@ -776,7 +777,7 @@ func TestMultiContainerKillAll(t *testing.T) {
}
}
- c, err := Load(conf.RootDir, ids[1])
+ c, err := LoadAndCheck(conf.RootDir, ids[1])
if err != nil {
t.Fatalf("failed to load child container %q: %v", c.ID, err)
}
@@ -899,7 +900,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) {
// Container is not thread safe, so load another instance to run in
// concurrently.
- startCont, err := Load(rootDir, ids[i])
+ startCont, err := LoadAndCheck(rootDir, ids[i])
if err != nil {
t.Fatalf("error loading container: %v", err)
}
@@ -1766,3 +1767,72 @@ func TestMultiContainerHomeEnvDir(t *testing.T) {
})
}
}
+
+func TestMultiContainerEvent(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"/bin/sleep", "100"}
+ quick := []string{"/bin/true"}
+ podSpec, ids := createSpecs(sleep, sleep, quick)
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ for _, cont := range containers {
+ t.Logf("Running containerd %s", cont.ID)
+ }
+
+ // Wait for last container to stabilize the process count that is checked
+ // further below.
+ if ws, err := containers[2].Wait(); err != nil || ws != 0 {
+ t.Fatalf("Container.Wait, status: %v, err: %v", ws, err)
+ }
+
+ // Check events for running containers.
+ for _, cont := range containers[:2] {
+ evt, err := cont.Event()
+ if err != nil {
+ t.Errorf("Container.Events(): %v", err)
+ }
+ if want := "stats"; evt.Type != want {
+ t.Errorf("Wrong event type, want: %s, got :%s", want, evt.Type)
+ }
+ if cont.ID != evt.ID {
+ t.Errorf("Wrong container ID, want: %s, got :%s", cont.ID, evt.ID)
+ }
+ // Event.Data is an interface, so it comes from the wire was
+ // map[string]string. Marshal and unmarshall again to the correc type.
+ data, err := json.Marshal(evt.Data)
+ if err != nil {
+ t.Fatalf("invalid event data: %v", err)
+ }
+ var stats boot.Stats
+ if err := json.Unmarshal(data, &stats); err != nil {
+ t.Fatalf("invalid event data: %v", err)
+ }
+ // One process per remaining container.
+ if want := uint64(2); stats.Pids.Current != want {
+ t.Errorf("Wrong number of PIDs, want: %d, got :%d", want, stats.Pids.Current)
+ }
+ }
+
+ // Check that stop and destroyed containers return error.
+ if err := containers[1].Destroy(); err != nil {
+ t.Fatalf("container.Destroy: %v", err)
+ }
+ for _, cont := range containers[1:] {
+ _, err := cont.Event()
+ if err == nil {
+ t.Errorf("Container.Events() should have failed, cid:%s, state: %v", cont.ID, cont.Status)
+ }
+ }
+}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index c4309feb3..4a4110477 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -66,6 +66,10 @@ type Sandbox struct {
// Cgroup has the cgroup configuration for the sandbox.
Cgroup *cgroup.Cgroup `json:"cgroup"`
+ // OriginalOOMScoreAdj stores the value of oom_score_adj when the sandbox
+ // started, before it may be modified.
+ OriginalOOMScoreAdj int `json:"originalOomScoreAdj"`
+
// child is set if a sandbox process is a child of the current process.
//
// This field isn't saved to json, because only a creator of sandbox
@@ -739,6 +743,11 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
}
return err
}
+ s.OriginalOOMScoreAdj, err = specutils.GetOOMScoreAdj(cmd.Process.Pid)
+ if err != nil {
+ return err
+ }
+
s.child = true
s.Pid = cmd.Process.Pid
log.Infof("Sandbox started, PID: %d", s.Pid)
@@ -1133,11 +1142,11 @@ func (s *Sandbox) DestroyContainer(cid string) error {
func (s *Sandbox) destroyContainer(cid string) error {
if s.IsRootContainer(cid) {
- log.Debugf("Destroying root container %q by destroying sandbox", cid)
+ log.Debugf("Destroying root container by destroying sandbox, cid: %s", cid)
return s.destroy()
}
- log.Debugf("Destroying container %q in sandbox %q", cid, s.ID)
+ log.Debugf("Destroying container, cid: %s, sandbox: %s", cid, s.ID)
conn, err := s.sandboxConnect()
if err != nil {
return err
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 0392e3e83..fdbba1832 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -19,6 +19,7 @@ package specutils
import (
"encoding/json"
"fmt"
+ "io"
"io/ioutil"
"os"
"path"
@@ -169,7 +170,7 @@ func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) {
// ReadSpecFromFile reads an OCI runtime spec from the given File, and
// normalizes all relative paths into absolute by prepending the bundle dir.
func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) {
- if _, err := specFile.Seek(0, os.SEEK_SET); err != nil {
+ if _, err := specFile.Seek(0, io.SeekStart); err != nil {
return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err)
}
specBytes, err := ioutil.ReadAll(specFile)
@@ -344,15 +345,9 @@ func IsSupportedDevMount(m specs.Mount) bool {
var existingDevices = []string{
"/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr",
"/dev/null", "/dev/zero", "/dev/full", "/dev/random",
- "/dev/urandom", "/dev/shm", "/dev/pts", "/dev/ptmx",
+ "/dev/urandom", "/dev/shm", "/dev/ptmx",
}
dst := filepath.Clean(m.Destination)
- if dst == "/dev" {
- // OCI spec uses many different mounts for the things inside of '/dev'. We
- // have a single mount at '/dev' that is always mounted, regardless of
- // whether it was asked for, as the spec says we SHOULD.
- return false
- }
for _, dev := range existingDevices {
if dst == dev || strings.HasPrefix(dst, dev+"/") {
return false
@@ -425,7 +420,7 @@ func Mount(src, dst, typ string, flags uint32) error {
// Special case, as there is no source directory for proc mounts.
isDir = true
} else if fi, err := os.Stat(src); err != nil {
- return fmt.Errorf("Stat(%q) failed: %v", src, err)
+ return fmt.Errorf("stat(%q) failed: %v", src, err)
} else {
isDir = fi.IsDir()
}
@@ -433,25 +428,25 @@ func Mount(src, dst, typ string, flags uint32) error {
if isDir {
// Create the destination directory.
if err := os.MkdirAll(dst, 0777); err != nil {
- return fmt.Errorf("Mkdir(%q) failed: %v", dst, err)
+ return fmt.Errorf("mkdir(%q) failed: %v", dst, err)
}
} else {
// Create the parent destination directory.
parent := path.Dir(dst)
if err := os.MkdirAll(parent, 0777); err != nil {
- return fmt.Errorf("Mkdir(%q) failed: %v", parent, err)
+ return fmt.Errorf("mkdir(%q) failed: %v", parent, err)
}
// Create the destination file if it does not exist.
f, err := os.OpenFile(dst, syscall.O_CREAT, 0777)
if err != nil {
- return fmt.Errorf("Open(%q) failed: %v", dst, err)
+ return fmt.Errorf("open(%q) failed: %v", dst, err)
}
f.Close()
}
// Do the mount.
if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil {
- return fmt.Errorf("Mount(%q, %q, %d) failed: %v", src, dst, flags, err)
+ return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err)
}
return nil
}
@@ -486,35 +481,6 @@ func GetOOMScoreAdj(pid int) (int, error) {
return strconv.Atoi(strings.TrimSpace(string(data)))
}
-// GetParentPid gets the parent process ID of the specified PID.
-func GetParentPid(pid int) (int, error) {
- data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
- if err != nil {
- return 0, err
- }
-
- var cpid string
- var name string
- var state string
- var ppid int
- // Parse after the binary name.
- _, err = fmt.Sscanf(string(data),
- "%v %v %v %d",
- // cpid is ignored.
- &cpid,
- // name is ignored.
- &name,
- // state is ignored.
- &state,
- &ppid)
-
- if err != nil {
- return 0, err
- }
-
- return ppid, nil
-}
-
// EnvVar looks for a varible value in the env slice assuming the following
// format: "NAME=VALUE".
func EnvVar(env []string, name string) (string, bool) {
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
index 32c139204..7dfd4b693 100644
--- a/test/benchmarks/base/BUILD
+++ b/test/benchmarks/base/BUILD
@@ -13,7 +13,7 @@ go_library(
go_test(
name = "base_test",
- size = "large",
+ size = "enormous",
srcs = [
"size_test.go",
"startup_test.go",
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index 32bf2a992..d3e5efd4f 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -441,9 +441,20 @@ func (FilterOutputDestination) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
- rules := [][]string{
- {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
- {"-P", "OUTPUT", "DROP"},
+ var rules [][]string
+ if ipv6 {
+ rules = [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ // Allow solicited node multicast addresses so we can send neighbor
+ // solicitations.
+ {"-A", "OUTPUT", "-d", "ff02::1:ff00:0/104", "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ } else {
+ rules = [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
}
if err := filterTableRules(ipv6, rules); err != nil {
return err
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index dd9a18339..b98d99fb8 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -577,11 +577,18 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
connCh := make(chan int)
errCh := make(chan error)
go func() {
- connFD, _, err := syscall.Accept(sockfd)
- if err != nil {
- errCh <- err
+ for {
+ connFD, _, err := syscall.Accept(sockfd)
+ if errors.Is(err, syscall.EINTR) {
+ continue
+ }
+ if err != nil {
+ errCh <- err
+ return
+ }
+ connCh <- connFD
+ return
}
- connCh <- connFD
}()
// Wait for accept() to return or for the context to finish.
diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go
index eecfe0730..006988896 100644
--- a/test/packetimpact/netdevs/netdevs.go
+++ b/test/packetimpact/netdevs/netdevs.go
@@ -40,7 +40,7 @@ var (
deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`)
linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`)
inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`)
- inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`)
+ inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-F:/]+)`)
)
// ParseDevices parses the output from `ip addr show` into a map from device
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 1546d0d51..c03c2c62c 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -252,6 +252,9 @@ ALL_TESTS = [
expect_netstack_failure = True,
),
PacketimpactTestInfo(
+ name = "ipv4_fragment_reassembly",
+ ),
+ PacketimpactTestInfo(
name = "ipv6_fragment_reassembly",
),
PacketimpactTestInfo(
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index a90046f69..8fa585804 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -839,6 +839,61 @@ func (conn *TCPIPv4) Drain(t *testing.T) {
conn.sniffer.Drain(t)
}
+// IPv4Conn maintains the state for all the layers in a IPv4 connection.
+type IPv4Conn Connection
+
+// NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults.
+func NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make EtherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(outgoingIPv4, incomingIPv4)
+ if err != nil {
+ t.Fatalf("can't make IPv4State: %s", err)
+ }
+
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return IPv4Conn{
+ layerStates: []layerState{etherState, ipv4State},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+// Send sends a frame with ipv4 overriding the IPv4 layer defaults and
+// additionalLayers added after it.
+func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(c).send(t, Layers{&ipv4}, additionalLayers...)
+}
+
+// Close cleans up any resources held.
+func (c *IPv4Conn) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(c).Close(t)
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, an error is returned.
+func (c *IPv4Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ return (*Connection)(c).ExpectFrame(t, frame, timeout)
+}
+
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
type IPv6Conn Connection
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
index d0e68c5da..0fc3d97b4 100644
--- a/test/packetimpact/testbench/dut_client.go
+++ b/test/packetimpact/testbench/dut_client.go
@@ -19,7 +19,7 @@ import (
pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
)
-// PosixClient is a gRPC client for the Posix service.
+// POSIXClient is a gRPC client for the Posix service.
type POSIXClient pb.PosixClient
// NewPOSIXClient makes a new gRPC client for the POSIX service.
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index a35562ca8..af7a2ba4e 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -879,6 +879,9 @@ type ICMPv4 struct {
Type *header.ICMPv4Type
Code *header.ICMPv4Code
Checksum *uint16
+ Ident *uint16
+ Sequence *uint16
+ Payload []byte
}
func (l *ICMPv4) String() string {
@@ -887,7 +890,7 @@ func (l *ICMPv4) String() string {
// ToBytes implements Layer.ToBytes.
func (l *ICMPv4) ToBytes() ([]byte, error) {
- b := make([]byte, header.ICMPv4MinimumSize)
+ b := make([]byte, header.ICMPv4MinimumSize+len(l.Payload))
h := header.ICMPv4(b)
if l.Type != nil {
h.SetType(*l.Type)
@@ -895,15 +898,33 @@ func (l *ICMPv4) ToBytes() ([]byte, error) {
if l.Code != nil {
h.SetCode(*l.Code)
}
+ if copied := copy(h.Payload(), l.Payload); copied != len(l.Payload) {
+ panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", len(h.Payload()), len(l.Payload)))
+ }
+ if l.Ident != nil {
+ h.SetIdent(*l.Ident)
+ }
+ if l.Sequence != nil {
+ h.SetSequence(*l.Sequence)
+ }
+
+ // The checksum must be handled last because the ICMPv4 header fields are
+ // included in the computation.
if l.Checksum != nil {
h.SetChecksum(*l.Checksum)
- return h, nil
- }
- payload, err := payload(l)
- if err != nil {
- return nil, err
+ } else {
+ // Compute the checksum based on the ICMPv4.Payload and also the subsequent
+ // layers.
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ var vv buffer.VectorisedView
+ vv.AppendView(buffer.View(l.Payload))
+ vv.Append(payload)
+ h.SetChecksum(header.ICMPv4Checksum(h, vv))
}
- h.SetChecksum(header.ICMPv4Checksum(h, payload))
+
return h, nil
}
@@ -915,8 +936,11 @@ func parseICMPv4(b []byte) (Layer, layerParser) {
Type: ICMPv4Type(h.Type()),
Code: ICMPv4Code(h.Code()),
Checksum: Uint16(h.Checksum()),
+ Ident: Uint16(h.Ident()),
+ Sequence: Uint16(h.Sequence()),
+ Payload: h.Payload(),
}
- return &icmpv4, parsePayload
+ return &icmpv4, nil
}
func (l *ICMPv4) match(other Layer) bool {
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 8c2de5a9f..c30c77a17 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -298,6 +298,18 @@ packetimpact_testbench(
)
packetimpact_testbench(
+ name = "ipv4_fragment_reassembly",
+ srcs = ["ipv4_fragment_reassembly_test.go"],
+ deps = [
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
name = "ipv6_fragment_reassembly",
srcs = ["ipv6_fragment_reassembly_test.go"],
deps = [
@@ -305,6 +317,7 @@ packetimpact_testbench(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//test/packetimpact/testbench",
+ "@com_github_google_go_cmp//cmp:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
new file mode 100644
index 000000000..65c0df140
--- /dev/null
+++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_fragment_reassembly_test
+
+import (
+ "flag"
+ "math/rand"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+type fragmentInfo struct {
+ offset uint16
+ size uint16
+ more uint8
+}
+
+func TestIPv4FragmentReassembly(t *testing.T) {
+ const fragmentID = 42
+ icmpv4ProtoNum := uint8(header.ICMPv4ProtocolNumber)
+
+ tests := []struct {
+ description string
+ ipPayloadLen int
+ fragments []fragmentInfo
+ expectReply bool
+ }{
+ {
+ description: "basic reassembly",
+ ipPayloadLen: 2000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
+ {offset: 1000, size: 1000, more: 0},
+ },
+ expectReply: true,
+ },
+ {
+ description: "out of order fragments",
+ ipPayloadLen: 2000,
+ fragments: []fragmentInfo{
+ {offset: 1000, size: 1000, more: 0},
+ {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments},
+ },
+ expectReply: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ conn := testbench.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{})
+ defer conn.Close(t)
+
+ data := make([]byte, test.ipPayloadLen)
+ icmp := header.ICMPv4(data[:header.ICMPv4MinimumSize])
+ icmp.SetType(header.ICMPv4Echo)
+ icmp.SetCode(header.ICMPv4UnusedCode)
+ icmp.SetChecksum(0)
+ icmp.SetSequence(0)
+ icmp.SetIdent(0)
+ originalPayload := data[header.ICMPv4MinimumSize:]
+ if _, err := rand.Read(originalPayload); err != nil {
+ t.Fatalf("rand.Read: %s", err)
+ }
+ cksum := header.ICMPv4Checksum(
+ icmp,
+ buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}),
+ )
+ icmp.SetChecksum(cksum)
+
+ for _, fragment := range test.fragments {
+ conn.Send(t,
+ testbench.IPv4{
+ Protocol: &icmpv4ProtoNum,
+ FragmentOffset: testbench.Uint16(fragment.offset),
+ Flags: testbench.Uint8(fragment.more),
+ ID: testbench.Uint16(fragmentID),
+ },
+ &testbench.Payload{
+ Bytes: data[fragment.offset:][:fragment.size],
+ })
+ }
+
+ var bytesReceived int
+ reassembledPayload := make([]byte, test.ipPayloadLen)
+ for {
+ incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv4{},
+ &testbench.ICMPv4{},
+ }, time.Second)
+ if err != nil {
+ // Either an unexpected frame was received, or none at all.
+ if bytesReceived < test.ipPayloadLen {
+ t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
+ }
+ break
+ }
+ if !test.expectReply {
+ t.Fatalf("unexpected reply received:\n%s", incomingFrame)
+ }
+ ipPayload, err := incomingFrame[2 /* ICMPv4 */].ToBytes()
+ if err != nil {
+ t.Fatalf("failed to parse ICMPv4 header: incomingPacket[2].ToBytes() = (_, %s)", err)
+ }
+ offset := *incomingFrame[1 /* IPv4 */].(*testbench.IPv4).FragmentOffset
+ if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) {
+ t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload))
+ }
+ bytesReceived += len(ipPayload)
+ }
+
+ if test.expectReply {
+ if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv4MinimumSize:]); diff != "" {
+ t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
index a24c85566..4a29de688 100644
--- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
@@ -15,154 +15,137 @@
package ipv6_fragment_reassembly_test
import (
- "bytes"
- "encoding/binary"
- "encoding/hex"
"flag"
+ "math/rand"
"net"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/test/packetimpact/testbench"
)
-const (
- // The payload length for the first fragment we send. This number
- // is a multiple of 8 near 750 (half of 1500).
- firstPayloadLength = 752
- // The ID field for our outgoing fragments.
- fragmentID = 1
- // A node must be able to accept a fragmented packet that,
- // after reassembly, is as large as 1500 octets.
- reassemblyCap = 1500
-)
-
func init() {
testbench.RegisterFlags(flag.CommandLine)
}
-func TestIPv6FragmentReassembly(t *testing.T) {
- dut := testbench.NewDUT(t)
- defer dut.TearDown()
- conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
- defer conn.Close(t)
-
- firstPayloadToSend := make([]byte, firstPayloadLength)
- for i := range firstPayloadToSend {
- firstPayloadToSend[i] = 'A'
- }
-
- secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize
- secondPayloadToSend := firstPayloadToSend[:secondPayloadLength]
-
- icmpv6EchoPayload := make([]byte, 4)
- binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0)
- binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0)
- icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...)
-
- lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16())
- rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16())
- icmpv6 := testbench.ICMPv6{
- Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
- Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
- Payload: icmpv6EchoPayload,
- }
- icmpv6Bytes, err := icmpv6.ToBytes()
- if err != nil {
- t.Fatalf("failed to serialize ICMPv6: %s", err)
- }
- cksum := header.ICMPv6Checksum(
- header.ICMPv6(icmpv6Bytes),
- lIP,
- rIP,
- buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}),
- )
-
- conn.Send(t, testbench.IPv6{},
- &testbench.IPv6FragmentExtHdr{
- FragmentOffset: testbench.Uint16(0),
- MoreFragments: testbench.Bool(true),
- Identification: testbench.Uint32(fragmentID),
- },
- &testbench.ICMPv6{
- Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
- Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
- Payload: icmpv6EchoPayload,
- Checksum: &cksum,
- })
+type fragmentInfo struct {
+ offset uint16
+ size uint16
+ more bool
+}
+func TestIPv6FragmentReassembly(t *testing.T) {
+ const fragmentID = 42
icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)
- conn.Send(t, testbench.IPv6{},
- &testbench.IPv6FragmentExtHdr{
- NextHeader: &icmpv6ProtoNum,
- FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8),
- MoreFragments: testbench.Bool(false),
- Identification: testbench.Uint32(fragmentID),
+ tests := []struct {
+ description string
+ ipPayloadLen int
+ fragments []fragmentInfo
+ expectReply bool
+ }{
+ {
+ description: "basic reassembly",
+ ipPayloadLen: 1500,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 760, more: true},
+ {offset: 760, size: 740, more: false},
+ },
+ expectReply: true,
},
- &testbench.Payload{
- Bytes: secondPayloadToSend,
- })
-
- gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{
- &testbench.Ether{},
- &testbench.IPv6{},
- &testbench.IPv6FragmentExtHdr{
- FragmentOffset: testbench.Uint16(0),
- MoreFragments: testbench.Bool(true),
+ {
+ description: "out of order fragments",
+ ipPayloadLen: 3000,
+ fragments: []fragmentInfo{
+ {offset: 0, size: 1024, more: true},
+ {offset: 2048, size: 952, more: false},
+ {offset: 1024, size: 1024, more: true},
+ },
+ expectReply: true,
},
- &testbench.ICMPv6{
- Type: testbench.ICMPv6Type(header.ICMPv6EchoReply),
- Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
- },
- }, time.Second)
- if err != nil {
- t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err)
}
- id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification
- gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes()
- if err != nil {
- t.Fatalf("failed to serialize ICMPv6: %s", err)
- }
- icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:]
- receivedLen := len(icmpPayload)
- wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen
- wantFirstPayload := make([]byte, receivedLen)
- for i := range wantFirstPayload {
- wantFirstPayload[i] = 'A'
- }
- wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen]
- if !bytes.Equal(icmpPayload, wantFirstPayload) {
- t.Fatalf("received unexpected payload, got: %s, want: %s",
- hex.Dump(icmpPayload),
- hex.Dump(wantFirstPayload))
- }
-
- gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{
- &testbench.Ether{},
- &testbench.IPv6{},
- &testbench.IPv6FragmentExtHdr{
- NextHeader: &icmpv6ProtoNum,
- FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)),
- MoreFragments: testbench.Bool(false),
- Identification: &id,
- },
- &testbench.ICMPv6{},
- }, time.Second)
- if err != nil {
- t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err)
- }
- secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes()
- if err != nil {
- t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err)
- }
- if !bytes.Equal(secondPayload, wantSecondPayload) {
- t.Fatalf("received unexpected payload, got: %s, want: %s",
- hex.Dump(secondPayload),
- hex.Dump(wantSecondPayload))
+ for _, test := range tests {
+ t.Run(test.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ defer conn.Close(t)
+
+ lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16())
+ rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16())
+
+ data := make([]byte, test.ipPayloadLen)
+ icmp := header.ICMPv6(data[:header.ICMPv6HeaderSize])
+ icmp.SetType(header.ICMPv6EchoRequest)
+ icmp.SetCode(header.ICMPv6UnusedCode)
+ icmp.SetChecksum(0)
+ originalPayload := data[header.ICMPv6HeaderSize:]
+ if _, err := rand.Read(originalPayload); err != nil {
+ t.Fatalf("rand.Read: %s", err)
+ }
+
+ cksum := header.ICMPv6Checksum(
+ icmp,
+ lIP,
+ rIP,
+ buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}),
+ )
+ icmp.SetChecksum(cksum)
+
+ for _, fragment := range test.fragments {
+ conn.Send(t, testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{
+ NextHeader: &icmpv6ProtoNum,
+ FragmentOffset: testbench.Uint16(fragment.offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit),
+ MoreFragments: testbench.Bool(fragment.more),
+ Identification: testbench.Uint32(fragmentID),
+ },
+ &testbench.Payload{
+ Bytes: data[fragment.offset:][:fragment.size],
+ })
+ }
+
+ var bytesReceived int
+ reassembledPayload := make([]byte, test.ipPayloadLen)
+ for {
+ incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{},
+ &testbench.ICMPv6{},
+ }, time.Second)
+ if err != nil {
+ // Either an unexpected frame was received, or none at all.
+ if bytesReceived < test.ipPayloadLen {
+ t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err)
+ }
+ break
+ }
+ if !test.expectReply {
+ t.Fatalf("unexpected reply received:\n%s", incomingFrame)
+ }
+ ipPayload, err := incomingFrame[3 /* ICMPv6 */].ToBytes()
+ if err != nil {
+ t.Fatalf("failed to parse ICMPv6 header: incomingPacket[3].ToBytes() = (_, %s)", err)
+ }
+ offset := *incomingFrame[2 /* IPv6FragmentExtHdr */].(*testbench.IPv6FragmentExtHdr).FragmentOffset
+ offset *= header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) {
+ t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload))
+ }
+ bytesReceived += len(ipPayload)
+ }
+
+ if test.expectReply {
+ if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv6HeaderSize:]); diff != "" {
+ t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff)
+ }
+ }
+ })
}
}
diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go
index 2f57dff19..8a1fe1279 100644
--- a/test/packetimpact/tests/tcp_network_unreachable_test.go
+++ b/test/packetimpact/tests/tcp_network_unreachable_test.go
@@ -74,7 +74,9 @@ func TestTCPSynSentUnreachable(t *testing.T) {
}
var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{
Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable),
- Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)}
+ Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable),
+ }
+
layers = append(layers, &icmpv4, ip, tcp)
rawConn.SendFrameStateless(t, layers)
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
index 4243eb59e..0dcc0fdea 100644
--- a/test/root/oom_score_adj_test.go
+++ b/test/root/oom_score_adj_test.go
@@ -40,11 +40,7 @@ var (
// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
// single container sandbox.
func TestOOMScoreAdjSingle(t *testing.T) {
- ppid, err := specutils.GetParentPid(os.Getpid())
- if err != nil {
- t.Fatalf("getting parent pid: %v", err)
- }
- parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid())
if err != nil {
t.Fatalf("getting parent oom_score_adj: %v", err)
}
@@ -122,11 +118,7 @@ func TestOOMScoreAdjSingle(t *testing.T) {
// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
// multi-container sandbox.
func TestOOMScoreAdjMulti(t *testing.T) {
- ppid, err := specutils.GetParentPid(os.Getpid())
- if err != nil {
- t.Fatalf("getting parent pid: %v", err)
- }
- parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid())
if err != nil {
t.Fatalf("getting parent oom_score_adj: %v", err)
}
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 9b5994d59..4992147d4 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -97,10 +97,10 @@ def _syscall_test(
# we figure out how to request ipv4 sockets on Guitar machines.
if network == "host":
tags.append("noguitar")
- tags.append("block-network")
# Disable off-host networking.
tags.append("requires-net:loopback")
+ tags.append("block-network")
# gotsan makes sense only if tests are running in gVisor.
if platform == "native":
diff --git a/test/runner/runner.go b/test/runner/runner.go
index 22d535f8d..7ab2c3edf 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -53,6 +53,9 @@ var (
runscPath = flag.String("runsc", "", "path to runsc binary")
addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests")
+ // TODO(gvisor.dev/issue/4572): properly support leak checking for runsc, and
+ // set to true as the default for the test runner.
+ leakCheck = flag.Bool("leak-check", false, "check for reference leaks")
)
// runTestCaseNative runs the test case directly on the host machine.
@@ -174,6 +177,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
if *addUDSTree {
args = append(args, "-fsgofer-host-uds")
}
+ if *leakCheck {
+ args = append(args, "-ref-leak-mode=log-names")
+ }
testLogDir := ""
if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv
index d978baca7..e41441374 100644
--- a/test/runtimes/exclude/java11.csv
+++ b/test/runtimes/exclude/java11.csv
@@ -144,6 +144,7 @@ jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing f
jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr
jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event
jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default'
+jdk/jfr/event/oldobject/TestLargeRootSet.java,,Flaky - `main' threw exception: java.lang.RuntimeException: Could not find root object
jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold
jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base
jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv
index ba993814f..c4e7917ec 100644
--- a/test/runtimes/exclude/nodejs12.4.0.csv
+++ b/test/runtimes/exclude/nodejs12.4.0.csv
@@ -1,31 +1,22 @@
test name,bug id,comment
async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29
-benchmark/test-benchmark-fs.js,,
-benchmark/test-benchmark-napi.js,,
+benchmark/test-benchmark-fs.js,,Broken test
+benchmark/test-benchmark-napi.js,,Broken test
doctool/test-make-doc.js,b/68848110,Expected to fail.
internet/test-dgram-multicast-set-interface-lo.js,b/162798882,
-internet/test-doctool-versions.js,,
-internet/test-uv-threadpool-schedule.js,,
-parallel/test-cluster-dgram-reuse.js,b/64024294,
+internet/test-doctool-versions.js,,Broken test
+internet/test-uv-threadpool-schedule.js,,Broken test
parallel/test-dgram-bind-fd.js,b/132447356,
parallel/test-dgram-socket-buffer-size.js,b/68847921,
parallel/test-dns-channel-timeout.js,b/161893056,
-parallel/test-fs-access.js,,
-parallel/test-fs-watchfile.js,,Flaky - File already exists error
-parallel/test-fs-write-stream.js,b/166819807,Flaky
-parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky
-parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky
-parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
+parallel/test-fs-access.js,,Broken test
+parallel/test-fs-watchfile.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky - VFS1 only
+parallel/test-http-writable-true-after-close.js,b/171301436,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
parallel/test-os.js,b/63997097,
-parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE
-parallel/test-process-uid-gid.js,,
-parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE
-parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE
+parallel/test-process-uid-gid.js,,Does not work inside Docker with gid nobody
pseudo-tty/test-assert-colors.js,b/162801321,
pseudo-tty/test-assert-no-color.js,b/162801321,
pseudo-tty/test-assert-position-indicator.js,b/162801321,
@@ -48,11 +39,7 @@ pseudo-tty/test-tty-stdout-resize.js,b/162801321,
pseudo-tty/test-tty-stream-constructors.js,b/162801321,
pseudo-tty/test-tty-window-size.js,b/162801321,
pseudo-tty/test-tty-wrap.js,b/162801321,
-pummel/test-heapdump-http2.js,,Flaky
-pummel/test-net-pingpong.js,,
+pummel/test-net-pingpong.js,,Broken test
pummel/test-vm-memleak.js,b/162799436,
-pummel/test-watch-file.js,,Flaky - Timeout
-sequential/test-child-process-pass-fd.js,b/63926391,Flaky
-sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE
-sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout
-tick-processor/test-tick-processor-builtin.js,,
+pummel/test-watch-file.js,,Flaky - VFS1 only
+tick-processor/test-tick-processor-builtin.js,,Broken test
diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv
index a73f3bcfb..c051fe571 100644
--- a/test/runtimes/exclude/php7.3.6.csv
+++ b/test/runtimes/exclude/php7.3.6.csv
@@ -8,6 +8,7 @@ ext/mbstring/tests/bug77165.phpt,,
ext/mbstring/tests/bug77454.phpt,,
ext/mbstring/tests/mb_convert_encoding_leak.phpt,,
ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,,
+ext/pcre/tests/cache_limit.phpt,,Broken test - Flaky
ext/session/tests/session_module_name_variation4.phpt,,Flaky
ext/session/tests/session_set_save_handler_class_018.phpt,,
ext/session/tests/session_set_save_handler_iface_003.phpt,,
@@ -26,13 +27,14 @@ ext/standard/tests/file/php_fd_wrapper_01.phpt,,
ext/standard/tests/file/php_fd_wrapper_02.phpt,,
ext/standard/tests/file/php_fd_wrapper_03.phpt,,
ext/standard/tests/file/php_fd_wrapper_04.phpt,,
-ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,
+ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,VFS1 only failure
ext/standard/tests/file/rename_variation.phpt,b/68717309,
ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341,
-ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,VFS1 only failure
ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb
+ext/standard/tests/streams/proc_open_bug64438.phpt,,Flaky
ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky
ext/standard/tests/streams/stream_socket_sendto.phpt,,
ext/standard/tests/strings/007.phpt,,
diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv
index 8760f8951..e9fef03b7 100644
--- a/test/runtimes/exclude/python3.7.3.csv
+++ b/test/runtimes/exclude/python3.7.3.csv
@@ -4,7 +4,6 @@ test_asyncore,b/162973328,
test_epoll,b/162983393,
test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode.
test_httplib,b/163000009,OSError: [Errno 98] Address already in use
-test_imaplib,b/162979661,
test_logging,b/162980079,
test_multiprocessing_fork,,Flaky. Sometimes times out.
test_multiprocessing_forkserver,,Flaky. Sometimes times out.
@@ -18,4 +17,3 @@ test_selectors,b/76116849,OSError not raised with epoll
test_smtplib,b/162980434,unclosed sockets
test_signal,,Flaky - signal: alarm clock
test_socket,b/75983380,
-test_subprocess,b/162980831,
diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go
index e5607ac92..81cb68381 100644
--- a/test/runtimes/proctor/main.go
+++ b/test/runtimes/proctor/main.go
@@ -22,6 +22,7 @@ import (
"log"
"os"
"strings"
+ "syscall"
"gvisor.dev/gvisor/test/runtimes/proctor/lib"
)
@@ -33,6 +34,29 @@ var (
pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
)
+// setNumFilesLimit changes the NOFILE soft rlimit if it is too high.
+func setNumFilesLimit() error {
+ // In docker containers, the default value of the NOFILE limit is
+ // 1048576. A few runtime tests (e.g. python:test_subprocess)
+ // enumerates all possible file descriptors and these tests can fail by
+ // timeout if the NOFILE limit is too high. On gVisor, syscalls are
+ // slower so these tests will need even more time to pass.
+ const nofile = 32768
+ rLimit := syscall.Rlimit{}
+ err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ if err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ if rLimit.Cur > nofile {
+ rLimit.Cur = nofile
+ err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ if err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ }
+ return nil
+}
+
func main() {
flag.Parse()
@@ -74,6 +98,10 @@ func main() {
tests = strings.Split(*testNames, ",")
}
+ if err := setNumFilesLimit(); err != nil {
+ log.Fatalf("%v", err)
+ }
+
// Run tests.
cmds := tr.TestCmds(tests)
for _, cmd := range cmds {
diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go
index 78285cb0e..64e6e14db 100644
--- a/test/runtimes/runner/lib/lib.go
+++ b/test/runtimes/runner/lib/lib.go
@@ -34,8 +34,16 @@ import (
// RunTests is a helper that is called by main. It exists so that we can run
// defered functions before exiting. It returns an exit code that should be
// passed to os.Exit.
-func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int {
- // Get tests to exclude..
+func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, batchSize int, timeout time.Duration) int {
+ if partitionNum <= 0 || totalPartitions <= 0 || partitionNum > totalPartitions {
+ fmt.Fprintf(os.Stderr, "invalid partition %d of %d", partitionNum, totalPartitions)
+ return 1
+ }
+
+ // TODO(gvisor.dev/issue/1624): Remove those tests from all exclude lists
+ // that only fail with VFS1.
+
+ // Get tests to exclude.
excludes, err := getExcludes(excludeFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error())
@@ -55,7 +63,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat
// Get a slice of tests to run. This will also start a single Docker
// container that will be used to run each test. The final test will
// stop the Docker container.
- tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes)
+ tests, err := getTests(ctx, d, lang, image, partitionNum, totalPartitions, batchSize, timeout, excludes)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return 1
@@ -66,7 +74,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat
}
// getTests executes all tests as table tests.
-func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, partitionNum, totalPartitions, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) {
// Start the container.
opts := dockerutil.RunOpts{
Image: fmt.Sprintf("runtimes/%s", image),
@@ -86,6 +94,14 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
// shard.
tests := strings.Fields(list)
sort.Strings(tests)
+
+ partitionSize := len(tests) / totalPartitions
+ if partitionNum == totalPartitions {
+ tests = tests[(partitionNum-1)*partitionSize:]
+ } else {
+ tests = tests[(partitionNum-1)*partitionSize : partitionNum*partitionSize]
+ }
+
indices, err := testutil.TestIndicesForShard(len(tests))
if err != nil {
return nil, fmt.Errorf("TestsForShard() failed: %v", err)
@@ -116,8 +132,15 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
err error
)
+ state, err := d.Status(ctx)
+ if err != nil {
+ t.Fatalf("Could not find container status: %v", err)
+ }
+ if !state.Running {
+ t.Fatalf("container is not running: state = %s", state.Status)
+ }
+
go func() {
- fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n"))
output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ","))
close(done)
}()
@@ -125,12 +148,12 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string,
select {
case <-done:
if err == nil {
- fmt.Printf("PASS: (%v)\n\n", time.Since(now))
+ fmt.Printf("PASS: (%v) %d tests passed\n", time.Since(now), len(tcs))
return
}
- t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output)
+ t.Errorf("FAIL: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output)
case <-time.After(timeout):
- t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output)
+ t.Errorf("TIMEOUT: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output)
}
},
})
diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go
index ec79a22c2..5b3443e36 100644
--- a/test/runtimes/runner/main.go
+++ b/test/runtimes/runner/main.go
@@ -25,11 +25,13 @@ import (
)
var (
- lang = flag.String("lang", "", "language runtime to test")
- image = flag.String("image", "", "docker image with runtime tests")
- excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
- batchSize = flag.Int("batch", 50, "number of test cases run in one command")
- timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
+ partition = flag.Int("partition", 1, "partition number, this is 1-indexed")
+ totalPartitions = flag.Int("total_partitions", 1, "total number of partitions")
+ batchSize = flag.Int("batch", 50, "number of test cases run in one command")
+ timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout")
)
func main() {
@@ -38,5 +40,5 @@ func main() {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
- os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout))
+ os.Exit(lib.RunTests(*lang, *image, *excludeFile, *partition, *totalPartitions, *batchSize, *timeout))
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index f66a9ceb4..b5a4ef4df 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -695,6 +695,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:socket_ip_unbound_netlink_test",
+)
+
+syscall_test(
test = "//test/syscalls/linux:socket_netdevice_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 572f39a5d..2350f7e69 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1285,6 +1285,7 @@ cc_binary(
"//test/util:mount_util",
"//test/util:multiprocess_util",
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -1801,10 +1802,14 @@ cc_binary(
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
gtest,
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
"//test/util:logging",
+ "//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:platform_util",
"//test/util:signal_util",
+ "//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
@@ -2101,10 +2106,12 @@ cc_binary(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
gtest,
+ "//test/util:signal_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:timer_util",
],
)
@@ -2124,6 +2131,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:timer_util",
],
)
@@ -2137,10 +2145,12 @@ cc_binary(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
gtest,
+ "//test/util:signal_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:timer_util",
],
)
@@ -2434,6 +2444,7 @@ cc_library(
"@com_google_absl//absl/memory",
gtest,
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:test_util",
],
alwayslink = 1,
@@ -2878,6 +2889,24 @@ cc_binary(
)
cc_binary(
+ name = "socket_ip_unbound_netlink_test",
+ testonly = 1,
+ srcs = [
+ "socket_ip_unbound_netlink.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_netlink_route_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "socket_domain_test",
testonly = 1,
srcs = [
@@ -3441,6 +3470,7 @@ cc_binary(
"@com_google_absl//absl/strings",
gtest,
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc
index b96907b30..1635c6d0c 100644
--- a/test/syscalls/linux/mknod.cc
+++ b/test/syscalls/linux/mknod.cc
@@ -125,6 +125,16 @@ TEST(MknodTest, Socket) {
ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds());
}
+PosixErrorOr<FileDescriptor> OpenRetryEINTR(std::string const& path, int flags,
+ mode_t mode = 0) {
+ while (true) {
+ auto maybe_fd = Open(path, flags, mode);
+ if (maybe_fd.ok() || maybe_fd.error().errno_value() != EINTR) {
+ return maybe_fd;
+ }
+ }
+}
+
TEST(MknodTest, Fifo) {
const std::string fifo = NewTempAbsPath();
ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0),
@@ -139,14 +149,16 @@ TEST(MknodTest, Fifo) {
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
// Write-end of the pipe.
- FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY));
+ FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_WRONLY));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
}
@@ -164,15 +176,16 @@ TEST(MknodTest, FifoOtrunc) {
std::vector<char> buf(512);
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
// Write-end of the pipe.
- FileDescriptor wfd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
}
@@ -192,14 +205,15 @@ TEST(MknodTest, FifoTruncNoOp) {
std::vector<char> buf(512);
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
- FileDescriptor wfd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC));
EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index e52c9cbcb..83546830d 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -592,6 +592,12 @@ TEST_F(MMapTest, ProtExec) {
memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code));
+#if defined(__aarch64__)
+ // We use this as a memory barrier for Arm64.
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_EXEC),
+ SyscallSucceeds());
+#endif
+
func = reinterpret_cast<uint32_t (*)(void)>(addr);
EXPECT_EQ(42, func());
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index 3aab25b23..d65b7d031 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -34,6 +34,7 @@
#include "test/util/mount_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -131,7 +132,9 @@ TEST(MountTest, UmountDetach) {
ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700",
/* umountflags= */ MNT_DETACH));
const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
- EXPECT_NE(before.st_ino, after.st_ino);
+ EXPECT_FALSE(before.st_dev == after.st_dev && before.st_ino == after.st_ino)
+ << "mount point has device number " << before.st_dev
+ << " and inode number " << before.st_ino << " before and after mount";
// Create files in the new mount.
constexpr char kContents[] = "no no no";
@@ -147,12 +150,14 @@ TEST(MountTest, UmountDetach) {
// Unmount the tmpfs.
mount.Release()();
- // Only check for inode number equality if the directory is not in overlayfs.
- // If xino option is not enabled and if all overlayfs layers do not belong to
- // the same filesystem then "the value of st_ino for directory objects may not
- // be persistent and could change even while the overlay filesystem is
- // mounted." -- Documentation/filesystems/overlayfs.txt
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ //
+ // For overlayfs, if xino option is not enabled and if all overlayfs layers do
+ // not belong to the same filesystem then "the value of st_ino for directory
+ // objects may not be persistent and could change even while the overlay
+ // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(before.st_ino, after2.st_ino);
}
@@ -214,18 +219,23 @@ TEST(MountTest, MountTmpfs) {
const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(s.st_mode, S_IFDIR | 0700);
- EXPECT_NE(s.st_ino, before.st_ino);
+ EXPECT_FALSE(before.st_dev == s.st_dev && before.st_ino == s.st_ino)
+ << "mount point has device number " << before.st_dev
+ << " and inode number " << before.st_ino << " before and after mount";
EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777));
}
// Now that dir is unmounted again, we should have the old inode back.
- // Only check for inode number equality if the directory is not in overlayfs.
- // If xino option is not enabled and if all overlayfs layers do not belong to
- // the same filesystem then "the value of st_ino for directory objects may not
- // be persistent and could change even while the overlay filesystem is
- // mounted." -- Documentation/filesystems/overlayfs.txt
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ //
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ //
+ // For overlayfs, if xino option is not enabled and if all overlayfs layers do
+ // not belong to the same filesystem then "the value of st_ino for directory
+ // objects may not be persistent and could change even while the overlay
+ // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(before.st_ino, after.st_ino);
}
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index b558e3a01..a7c46adbf 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -664,6 +664,17 @@ TEST_P(RawPacketTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+TEST_P(RawPacketTest, GetSocketAcceptConn) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index e8fcc4439..7a0f33dff 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -26,6 +26,7 @@
#include <string.h>
#include <sys/mman.h>
#include <sys/prctl.h>
+#include <sys/ptrace.h>
#include <sys/stat.h>
#include <sys/statfs.h>
#include <sys/utsname.h>
@@ -512,6 +513,414 @@ TEST(ProcSelfAuxv, EntryValues) {
EXPECT_EQ(i, proc_auxv.size());
}
+// Just open and read a part of /proc/self/mem, check that we can read an item.
+TEST(ProcPidMem, Read) {
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ char input[] = "hello-world";
+ char output[sizeof(input)];
+ ASSERT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(input)),
+ SyscallSucceedsWithValue(sizeof(input)));
+ ASSERT_STREQ(input, output);
+}
+
+// Perform read on an unmapped region.
+TEST(ProcPidMem, Unmapped) {
+ // Strategy: map then unmap, so we have a guaranteed unmapped region
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ // Fill it with things
+ memset(mapping.ptr(), 'x', mapping.len());
+ char expected = 'x', output;
+ ASSERT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(mapping.ptr())),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(expected, output);
+
+ // Unmap region again
+ ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds());
+
+ // Now we want EIO error
+ ASSERT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(mapping.ptr())),
+ SyscallFailsWithErrno(EIO));
+}
+
+// Perform read repeatedly to verify offset change.
+TEST(ProcPidMem, RepeatedRead) {
+ auto const num_reads = 3;
+ char expected[] = "01234567890abcdefghijkl";
+ char output[sizeof(expected) / num_reads];
+
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ for (auto i = 0; i < num_reads; i++) {
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[i * sizeof(output)], output, sizeof(output)),
+ 0);
+ }
+}
+
+// Perform seek operations repeatedly.
+TEST(ProcPidMem, RepeatedSeek) {
+ auto const num_reads = 3;
+ char expected[] = "01234567890abcdefghijkl";
+ char output[sizeof(expected) / num_reads];
+
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ // Read from start
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0);
+ // Skip ahead one read
+ ASSERT_THAT(lseek(memfd.get(), sizeof(output), SEEK_CUR),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected) +
+ sizeof(output) * 2));
+ // Do read again
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[2 * sizeof(output)], output, sizeof(output)), 0);
+ // Skip back three reads
+ ASSERT_THAT(lseek(memfd.get(), -3 * sizeof(output), SEEK_CUR),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ // Do read again
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0);
+ // Check that SEEK_END does not work
+ ASSERT_THAT(lseek(memfd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL));
+}
+
+// Perform read past an allocated memory region.
+TEST(ProcPidMem, PartialRead) {
+ // Strategy: map large region, then do unmap and remap smaller region
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+
+ Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ Mapping smaller_mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(mapping.ptr(), kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+
+ // Fill it with things
+ memset(smaller_mapping.ptr(), 'x', smaller_mapping.len());
+
+ // Now we want no error
+ char expected[] = {'x'};
+ std::unique_ptr<char[]> output(new char[kPageSize]);
+ off_t read_offset =
+ reinterpret_cast<off_t>(smaller_mapping.ptr()) + kPageSize - 1;
+ ASSERT_THAT(
+ pread(memfd.get(), output.get(), sizeof(output.get()), read_offset),
+ SyscallSucceedsWithValue(sizeof(expected)));
+ // Since output is larger, than expected we have to do manual compare
+ ASSERT_EQ(expected[0], (output).get()[0]);
+}
+
+// Perform read on /proc/[pid]/mem after exit.
+TEST(ProcPidMem, AfterExit) {
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ char expected[] = "hello-world";
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char ok = 1;
+ TEST_CHECK(WriteFd(pfd1[1], &ok, sizeof(ok)) == sizeof(ok));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Wait for child to be alive and well
+ char ok = 0;
+ EXPECT_THAT(ReadFd(pfd1[0], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+
+ // Expect that we can read
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(&expected)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+
+ // Expect that we can't read anymore
+ EXPECT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(&expected)),
+ SyscallSucceedsWithValue(0));
+}
+
+// Read from /proc/[pid]/mem with different UID/GID and attached state.
+TEST(ProcPidMem, DifferentUserAttached) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_DAC_OVERRIDE)));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_PTRACE)));
+
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ ScopedThread([&] {
+ // Attach to child subprocess without stopping it
+ EXPECT_THAT(ptrace(PTRACE_SEIZE, child_pid, NULL, NULL), SyscallSucceeds());
+
+ // Keep capabilities after setuid
+ EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+
+ // Only restore CAP_SYS_PTRACE and CAP_DAC_OVERRIDE
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, true));
+ EXPECT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, true));
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+ char expected[] = "hello-world";
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(target_location)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+ });
+}
+
+// Attempt to read from /proc/[pid]/mem with different UID/GID.
+TEST(ProcPidMem, DifferentUser) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ ScopedThread([&] {
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+
+ // Attempt to open /proc/[child_pid]/mem
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ EXPECT_THAT(open(mempath.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+ });
+}
+
+// Perform read on /proc/[pid]/mem with same UID/GID.
+TEST(ProcPidMem, SameUser) {
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+ char expected[] = "hello-world";
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(target_location)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+}
+
// Just open and read /proc/self/maps, check that we can find [stack]
TEST(ProcSelfMaps, Basic) {
auto proc_self_maps =
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
index 9fb1b3a2c..738923822 100644
--- a/test/syscalls/linux/proc_pid_smaps.cc
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
// amount of whitespace).
if (!entry) {
std::cerr << "smaps line not considered a maps line: "
- << maybe_maps_entry.error_message() << std::endl;
+ << maybe_maps_entry.error().message() << std::endl;
return PosixError(
EINVAL,
absl::StrCat("smaps field line without preceding maps line: ", l));
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index 926690eb8..13c19d4a8 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -30,10 +30,13 @@
#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/util/fs_util.h"
#include "test/util/logging.h"
+#include "test/util/memory_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/platform_util.h"
#include "test/util/signal_util.h"
+#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
#include "test/util/time_util.h"
@@ -113,10 +116,21 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) {
// except disabled.
SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 0);
- constexpr long kBeforePokeDataValue = 10;
- constexpr long kAfterPokeDataValue = 20;
+ // Test PTRACE_POKE/PEEKDATA on both anonymous and file mappings.
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_NO_ERRNO(Truncate(file.path(), kPageSize));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ const auto file_mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
+ nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
- volatile long word = kBeforePokeDataValue;
+ constexpr long kBeforePokeDataAnonValue = 10;
+ constexpr long kAfterPokeDataAnonValue = 20;
+ constexpr long kBeforePokeDataFileValue = 0; // implicit, due to truncate()
+ constexpr long kAfterPokeDataFileValue = 30;
+
+ volatile long anon_word = kBeforePokeDataAnonValue;
+ auto* file_word_ptr = static_cast<volatile long*>(file_mapping.ptr());
pid_t const child_pid = fork();
if (child_pid == 0) {
@@ -134,12 +148,22 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) {
MaybeSave();
TEST_CHECK(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
- // Replace the value of word in the parent process with kAfterPokeDataValue.
- long const parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &word, 0);
+ // Replace the value of anon_word in the parent process with
+ // kAfterPokeDataAnonValue.
+ long parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &anon_word, 0);
+ MaybeSave();
+ TEST_CHECK(parent_word == kBeforePokeDataAnonValue);
+ TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, &anon_word,
+ kAfterPokeDataAnonValue) == 0);
+ MaybeSave();
+
+ // Replace the value pointed to by file_word_ptr in the mapped file with
+ // kAfterPokeDataFileValue, via the parent process' mapping.
+ parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, file_word_ptr, 0);
MaybeSave();
- TEST_CHECK(parent_word == kBeforePokeDataValue);
- TEST_PCHECK(
- ptrace(PTRACE_POKEDATA, parent_pid, &word, kAfterPokeDataValue) == 0);
+ TEST_CHECK(parent_word == kBeforePokeDataFileValue);
+ TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, file_word_ptr,
+ kAfterPokeDataFileValue) == 0);
MaybeSave();
// Detach from the parent and suppress the SIGSTOP. If the SIGSTOP is not
@@ -160,7 +184,8 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) {
<< " status " << status;
// Check that the child's PTRACE_POKEDATA was effective.
- EXPECT_EQ(kAfterPokeDataValue, word);
+ EXPECT_EQ(kAfterPokeDataAnonValue, anon_word);
+ EXPECT_EQ(kAfterPokeDataFileValue, *file_word_ptr);
}
TEST(PtraceTest, GetSigMask) {
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 1b9dbc584..bd779da92 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -438,6 +438,19 @@ TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+// Test getsockopt for SO_ACCEPTCONN.
+TEST_F(RawSocketICMPTest, GetSocketAcceptConn) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
// We're going to receive both the echo request and reply, but the order is
// indeterminate.
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index e9b131ca9..890f4a246 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <signal.h>
#include <sys/ipc.h>
#include <sys/sem.h>
#include <sys/types.h>
@@ -486,6 +487,292 @@ TEST(SemaphoreTest, SemIpcSet) {
ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES));
}
+TEST(SemaphoreTest, SemCtlIpcStat) {
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ const uid_t kUid = getuid();
+ const gid_t kGid = getgid();
+ time_t start_time = time(nullptr);
+
+ AutoSem sem(semget(IPC_PRIVATE, 10, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct semid_ds ds;
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds());
+
+ EXPECT_EQ(ds.sem_perm.__key, IPC_PRIVATE);
+ EXPECT_EQ(ds.sem_perm.uid, kUid);
+ EXPECT_EQ(ds.sem_perm.gid, kGid);
+ EXPECT_EQ(ds.sem_perm.cuid, kUid);
+ EXPECT_EQ(ds.sem_perm.cgid, kGid);
+ EXPECT_EQ(ds.sem_perm.mode, 0600);
+ // Last semop time is not set on creation.
+ EXPECT_EQ(ds.sem_otime, 0);
+ EXPECT_GE(ds.sem_ctime, start_time);
+ EXPECT_EQ(ds.sem_nsems, 10);
+
+ // The timestamps only have a resolution of seconds; slow down so we actually
+ // see the timestamps change.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Set semid_ds structure of the set.
+ auto last_ctime = ds.sem_ctime;
+ start_time = time(nullptr);
+ struct semid_ds semid_to_set = {};
+ semid_to_set.sem_perm.uid = kUid;
+ semid_to_set.sem_perm.gid = kGid;
+ semid_to_set.sem_perm.mode = 0666;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds());
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.mode, 0666);
+ EXPECT_GE(ds.sem_otime, start_time);
+ EXPECT_GT(ds.sem_ctime, last_ctime);
+
+ // An invalid semid fails the syscall with errno EINVAL.
+ EXPECT_THAT(semctl(sem.get() + 1, 0, IPC_STAT, &ds),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Make semaphore not readable and check the signal fails.
+ semid_to_set.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds());
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Calls semctl(semid, 0, cmd) until the returned value is >= target, an
+// internal timeout expires, or semctl returns an error.
+PosixErrorOr<int> WaitSemctl(int semid, int target, int cmd) {
+ constexpr absl::Duration timeout = absl::Seconds(10);
+ const auto deadline = absl::Now() + timeout;
+ int semcnt = 0;
+ while (absl::Now() < deadline) {
+ semcnt = semctl(semid, 0, cmd);
+ if (semcnt < 0) {
+ return PosixError(errno, "semctl(GETZCNT) failed");
+ }
+ if (semcnt >= target) {
+ break;
+ }
+ absl::SleepFor(absl::Milliseconds(10));
+ }
+ return semcnt;
+}
+
+TEST(SemaphoreTest, SemopGetzcnt) {
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ // Create a write only semaphore set.
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ // No read permission to retrieve semzcnt.
+ EXPECT_THAT(semctl(sem.get(), 0, GETZCNT), SyscallFailsWithErrno(EACCES));
+
+ // Remove the calling thread's read permission.
+ struct semid_ds ds = {};
+ ds.sem_perm.uid = getuid();
+ ds.sem_perm.gid = getgid();
+ ds.sem_perm.mode = 0600;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds());
+
+ std::vector<pid_t> children;
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
+
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = 0;
+ constexpr size_t kLoops = 10;
+ for (auto i = 0; i < kLoops; i++) {
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0);
+ _exit(0);
+ }
+ children.push_back(child_pid);
+ }
+
+ EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETZCNT),
+ IsPosixErrorOkAndHolds(kLoops));
+ // Set semval to 0, which wakes up children that sleep on the semop.
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 0), SyscallSucceeds());
+ for (const auto& child_pid : children) {
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+ EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0);
+}
+
+TEST(SemaphoreTest, SemopGetzcntOnSetRemoval) {
+ auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT);
+ ASSERT_THAT(semid, SyscallSucceeds());
+ ASSERT_THAT(semctl(semid, 0, SETVAL, 1), SyscallSucceeds());
+ ASSERT_EQ(semctl(semid, 0, GETZCNT), 0);
+
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = 0;
+
+ // Ensure that wait will only unblock when the semaphore is removed. On
+ // EINTR retry it may race with deletion and return EINVAL.
+ TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 &&
+ (errno == EIDRM || errno == EINVAL));
+ _exit(0);
+ }
+
+ EXPECT_THAT(WaitSemctl(semid, 1, GETZCNT), IsPosixErrorOkAndHolds(1));
+ // Remove the semaphore set, which fails the sleep semop.
+ ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ EXPECT_THAT(semctl(semid, 0, GETZCNT), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
+ ASSERT_EQ(semctl(sem.get(), 0, GETZCNT), 0);
+
+ // Saving will cause semop() to be spuriously interrupted.
+ DisableSave ds;
+
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR);
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = 0;
+
+ TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR);
+ _exit(0);
+ }
+
+ EXPECT_THAT(WaitSemctl(sem.get(), 1, GETZCNT), IsPosixErrorOkAndHolds(1));
+ // Send a signal to the child, which fails the sleep semop.
+ ASSERT_EQ(kill(child_pid, SIGHUP), 0);
+
+ ds.reset();
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0);
+}
+
+TEST(SemaphoreTest, SemopGetncnt) {
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ // Create a write only semaphore set.
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ // No read permission to retrieve semzcnt.
+ EXPECT_THAT(semctl(sem.get(), 0, GETNCNT), SyscallFailsWithErrno(EACCES));
+
+ // Remove the calling thread's read permission.
+ struct semid_ds ds = {};
+ ds.sem_perm.uid = getuid();
+ ds.sem_perm.gid = getgid();
+ ds.sem_perm.mode = 0600;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds());
+
+ std::vector<pid_t> children;
+
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = -1;
+ constexpr size_t kLoops = 10;
+ for (auto i = 0; i < kLoops; i++) {
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0);
+ _exit(0);
+ }
+ children.push_back(child_pid);
+ }
+ EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETNCNT),
+ IsPosixErrorOkAndHolds(kLoops));
+ // Set semval to 1, which wakes up children that sleep on the semop.
+ ASSERT_THAT(semctl(sem.get(), 0, SETVAL, kLoops), SyscallSucceeds());
+ for (const auto& child_pid : children) {
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+ EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
+}
+
+TEST(SemaphoreTest, SemopGetncntOnSetRemoval) {
+ auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT);
+ ASSERT_THAT(semid, SyscallSucceeds());
+ ASSERT_EQ(semctl(semid, 0, GETNCNT), 0);
+
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = -1;
+
+ // Ensure that wait will only unblock when the semaphore is removed. On
+ // EINTR retry it may race with deletion and return EINVAL
+ TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 &&
+ (errno == EIDRM || errno == EINVAL));
+ _exit(0);
+ }
+
+ EXPECT_THAT(WaitSemctl(semid, 1, GETNCNT), IsPosixErrorOkAndHolds(1));
+ // Remove the semaphore set, which fails the sleep semop.
+ ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds());
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ EXPECT_THAT(semctl(semid, 0, GETNCNT), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) {
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+ ASSERT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
+
+ // Saving will cause semop() to be spuriously interrupted.
+ DisableSave ds;
+
+ auto child_pid = fork();
+ if (child_pid == 0) {
+ TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR);
+ struct sembuf buf = {};
+ buf.sem_num = 0;
+ buf.sem_op = -1;
+
+ TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR);
+ _exit(0);
+ }
+ EXPECT_THAT(WaitSemctl(sem.get(), 1, GETNCNT), IsPosixErrorOkAndHolds(1));
+ // Send a signal to the child, which fails the sleep semop.
+ ASSERT_EQ(kill(child_pid, SIGHUP), 0);
+
+ ds.reset();
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index a8bfb01f1..cf0977118 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -25,9 +25,11 @@
#include "absl/time/time.h"
#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
+#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
namespace gvisor {
namespace testing {
@@ -629,6 +631,57 @@ TEST(SendFileTest, SendFileToPipe) {
SyscallSucceedsWithValue(kDataSize));
}
+static volatile int signaled = 0;
+void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
+
+TEST(SendFileTest, ToEventFDDoesNotSpin_NoRandomSave) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+
+ // Write the maximum value of an eventfd to a file.
+ const uint64_t kMaxEventfdValue = 0xfffffffffffffffe;
+ const auto tempfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const auto tempfd = ASSERT_NO_ERRNO_AND_VALUE(Open(tempfile.path(), O_RDWR));
+ ASSERT_THAT(
+ pwrite(tempfd.get(), &kMaxEventfdValue, sizeof(kMaxEventfdValue), 0),
+ SyscallSucceedsWithValue(sizeof(kMaxEventfdValue)));
+
+ // Set the eventfd's value to 1.
+ const uint64_t kOne = 1;
+ ASSERT_THAT(write(efd.get(), &kOne, sizeof(kOne)),
+ SyscallSucceedsWithValue(sizeof(kOne)));
+
+ // Set up signal handler.
+ struct sigaction sa = {};
+ sa.sa_sigaction = SigUsr1Handler;
+ sa.sa_flags = SA_SIGINFO;
+ const auto cleanup_sigact =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa));
+
+ // Send SIGUSR1 to this thread in 1 second.
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = SIGUSR1;
+ sev.sigev_notify_thread_id = gettid();
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ struct itimerspec its = {};
+ its.it_value = absl::ToTimespec(absl::Seconds(1));
+ DisableSave ds; // Asserting an EINTR.
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // Sendfile from tempfd to the eventfd. Since the eventfd is not already at
+ // its maximum value, the eventfd is "ready for writing"; however, since the
+ // eventfd's existing value plus the new value would exceed the maximum, the
+ // write should internally fail with EWOULDBLOCK. In this case, sendfile()
+ // should block instead of spinning, and eventually be interrupted by our
+ // timer. See b/172075629.
+ EXPECT_THAT(
+ sendfile(efd.get(), tempfd.get(), nullptr, sizeof(kMaxEventfdValue)),
+ SyscallFailsWithErrno(EINTR));
+
+ // Signal should have been handled.
+ EXPECT_EQ(signaled, 1);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 11fcec443..e19a83413 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -350,6 +350,10 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ // TODO(b/157236388): Remove Disable save after bug is fixed. S/R test can
+ // fail because the last socket may not be delivered to the accept queue
+ // by the time connect returns.
+ DisableSave ds;
for (int i = 0; i < kBacklog; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -554,7 +558,11 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
});
}
-TEST_P(SocketInetLoopbackTest, TCPbacklog) {
+// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -567,7 +575,8 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) {
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds());
+ constexpr int kBacklogSize = 2;
+ ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
@@ -931,7 +940,7 @@ void setupTimeWaitClose(const TestAddress* listener,
}
// shutdown to trigger TIME_WAIT.
- ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds());
+ ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds());
{
const int kTimeout = 10000;
struct pollfd pfd = {
@@ -941,7 +950,8 @@ void setupTimeWaitClose(const TestAddress* listener,
ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
ASSERT_EQ(pfd.revents, POLLIN);
}
- ScopedThread t([&]() {
+ ASSERT_THAT(shutdown(passive_closefd.get(), SHUT_WR), SyscallSucceeds());
+ {
constexpr int kTimeout = 10000;
constexpr int16_t want_events = POLLHUP;
struct pollfd pfd = {
@@ -949,11 +959,8 @@ void setupTimeWaitClose(const TestAddress* listener,
.events = want_events,
};
ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
- });
+ }
- passive_closefd.reset();
- t.Join();
- active_closefd.reset();
// This sleep is needed to reduce flake to ensure that the passive-close
// ensures the state transitions to CLOSE from LAST_ACK.
absl::SleepFor(absl::Seconds(1));
@@ -1143,6 +1150,9 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // TODO(b/157236388): Reenable Cooperative S/R once bug is fixed.
+ DisableSave ds;
ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 3f2c0fdf2..f69f8f99f 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -472,5 +472,19 @@ TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+// Test getsockopt for SO_ACCEPTCONN on udp socket.
+TEST_P(UDPSocketPairTest, GetSocketAcceptConn) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
index 8f7ccc868..029f1e872 100644
--- a/test/syscalls/linux/socket_ip_unbound.cc
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -454,23 +454,15 @@ TEST_P(IPUnboundSocketTest, SetReuseAddr) {
INSTANTIATE_TEST_SUITE_P(
IPUnboundSockets, IPUnboundSocketTest,
- ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
+ ::testing::ValuesIn(VecCat<SocketKind>(
ApplyVec<SocketKind>(IPv4UDPUnboundSocket,
- AllBitwiseCombinations(List<int>{SOCK_DGRAM},
- List<int>{0,
- SOCK_NONBLOCK})),
+ std::vector<int>{0, SOCK_NONBLOCK}),
ApplyVec<SocketKind>(IPv6UDPUnboundSocket,
- AllBitwiseCombinations(List<int>{SOCK_DGRAM},
- List<int>{0,
- SOCK_NONBLOCK})),
+ std::vector<int>{0, SOCK_NONBLOCK}),
ApplyVec<SocketKind>(IPv4TCPUnboundSocket,
- AllBitwiseCombinations(List<int>{SOCK_STREAM},
- List<int>{0,
- SOCK_NONBLOCK})),
+ std::vector{0, SOCK_NONBLOCK}),
ApplyVec<SocketKind>(IPv6TCPUnboundSocket,
- AllBitwiseCombinations(List<int>{SOCK_STREAM},
- List<int>{
- 0, SOCK_NONBLOCK}))))));
+ std::vector{0, SOCK_NONBLOCK}))));
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_unbound_netlink.cc b/test/syscalls/linux/socket_ip_unbound_netlink.cc
new file mode 100644
index 000000000..6036bfcaf
--- /dev/null
+++ b/test/syscalls/linux/socket_ip_unbound_netlink.cc
@@ -0,0 +1,104 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of IP sockets.
+using IPv6UnboundSocketTest = SimpleSocketTest;
+
+TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved
+ // across save/restore.
+ DisableSave ds;
+
+ // Delete the loopback address from the loopback interface.
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET6,
+ /*prefixlen=*/128, &in6addr_loopback,
+ sizeof(in6addr_loopback)));
+ Cleanup defer_addr_removal =
+ Cleanup([loopback_link = std::move(loopback_link)] {
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6,
+ /*prefixlen=*/128, &in6addr_loopback,
+ sizeof(in6addr_loopback)));
+ });
+
+ TestAddress addr = V6Loopback();
+ reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = 65535;
+ auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv6UnboundSocketTest,
+ ::testing::ValuesIn(std::vector<SocketKind>{
+ IPv6UDPUnboundSocket(0),
+ IPv6TCPUnboundSocket(0)}));
+
+using IPv4UnboundSocketTest = SimpleSocketTest;
+
+TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved
+ // across save/restore.
+ DisableSave ds;
+
+ // Delete the loopback address from the loopback interface.
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+ struct in_addr laddr;
+ laddr.s_addr = htonl(INADDR_LOOPBACK);
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/8, &laddr, sizeof(laddr)));
+ Cleanup defer_addr_removal = Cleanup(
+ [loopback_link = std::move(loopback_link), addr = std::move(laddr)] {
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/8, &addr, sizeof(addr)));
+ });
+ TestAddress addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = 65535;
+ auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv4UnboundSocketTest,
+ ::testing::ValuesIn(std::vector<SocketKind>{
+ IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)}));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index a72c76c97..b3f54e7f6 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -28,6 +28,7 @@
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -75,7 +76,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -209,7 +210,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -265,7 +266,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -321,7 +322,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -377,7 +378,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -437,7 +438,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -497,7 +498,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -553,7 +554,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -609,7 +610,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -669,7 +670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -727,7 +728,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -785,7 +786,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -845,7 +846,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -919,7 +920,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -977,7 +978,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -1330,8 +1331,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1409,8 +1410,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1432,8 +1433,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
char recv_buf[sizeof(send_buf)] = {};
for (auto& sockets : socket_pairs) {
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
}
@@ -1486,7 +1487,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1530,7 +1531,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
// Check that we don't receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -1580,7 +1581,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1627,7 +1628,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1678,7 +1679,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1737,7 +1738,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// of the other sockets to have received it, but we will check that later.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(send_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1745,9 +1746,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// Verify that no other messages were received.
for (auto& socket : sockets) {
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- PosixErrorIs(EAGAIN, ::testing::_));
+ EXPECT_THAT(
+ RecvTimeout(socket->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
}
@@ -2108,6 +2109,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
constexpr int kMessageSize = 10;
+ // Saving during each iteration of the following loop is too expensive.
+ DisableSave ds;
+
for (int i = 0; i < 100; ++i) {
// Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket
// each time so that a new ephemerial port will be used each time. This
@@ -2120,16 +2124,18 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(sizeof(send_buf)));
}
+ ds.reset();
+
// Check that both receivers got messages. This checks that we are using load
// balancing (REUSEPORT) instead of the most recently bound socket
// (REUSEADDR).
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- IsPosixErrorOkAndHolds(kMessageSize));
- EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- IsPosixErrorOkAndHolds(kMessageSize));
+ EXPECT_THAT(
+ RecvTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
+ EXPECT_THAT(
+ RecvTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
}
// Test that socket will receive packet info control message.
@@ -2193,8 +2199,8 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
received_msg.msg_control = received_cmsg_buf;
- ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
- SyscallSucceedsWithValue(kDataLength));
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
ASSERT_NE(cmsg, nullptr);
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
index 49a0f06d9..875016812 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
@@ -40,17 +40,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
/*prefixlen=*/24, &addr, sizeof(addr)));
Cleanup defer_addr_removal = Cleanup(
[loopback_link = std::move(loopback_link), addr = std::move(addr)] {
- if (IsRunningOnGvisor()) {
- // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
- // via netlink is supported.
- EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EOPNOTSUPP, ::testing::_));
- } else {
- EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr,
- sizeof(addr)));
- }
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
});
auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -124,17 +116,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
24 /* prefixlen */, &addr, sizeof(addr)));
Cleanup defer_addr_removal = Cleanup(
[loopback_link = std::move(loopback_link), addr = std::move(addr)] {
- if (IsRunningOnGvisor()) {
- // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
- // via netlink is supported.
- EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EOPNOTSUPP, ::testing::_));
- } else {
- EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr,
- sizeof(addr)));
- }
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
});
TestAddress broadcast_address("SubnetBroadcastAddress");
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index 241ddad74..ee3c08770 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -511,53 +511,42 @@ TEST(NetlinkRouteTest, LookupAll) {
ASSERT_GT(count, 0);
}
-TEST(NetlinkRouteTest, AddAddr) {
+TEST(NetlinkRouteTest, AddAndRemoveAddr) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ // Don't do cooperative save/restore because netstack state is not restored.
+ // TODO(gvisor.dev/issue/4595): enable cooperative save tests.
+ const DisableSave ds;
Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
-
- struct request {
- struct nlmsghdr hdr;
- struct ifaddrmsg ifa;
- struct rtattr rtattr;
- struct in_addr addr;
- char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
- };
-
- struct request req = {};
- req.hdr.nlmsg_type = RTM_NEWADDR;
- req.hdr.nlmsg_seq = kSeq;
- req.ifa.ifa_family = AF_INET;
- req.ifa.ifa_prefixlen = 24;
- req.ifa.ifa_flags = 0;
- req.ifa.ifa_scope = 0;
- req.ifa.ifa_index = loopback_link.index;
- req.rtattr.rta_type = IFA_LOCAL;
- req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
- inet_pton(AF_INET, "10.0.0.1", &req.addr);
- req.hdr.nlmsg_len =
- NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len);
+ struct in_addr addr;
+ ASSERT_EQ(inet_pton(AF_INET, "10.0.0.1", &addr), 1);
// Create should succeed, as no such address in kernel.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK;
- EXPECT_NO_ERRNO(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+ ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)));
+
+ Cleanup defer_addr_removal = Cleanup(
+ [loopback_link = std::move(loopback_link), addr = std::move(addr)] {
+ // First delete should succeed, as address exists.
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+
+ // Second delete should fail, as address no longer exists.
+ EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EADDRNOTAVAIL, ::testing::_));
+ });
// Replace an existing address should succeed.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
- req.hdr.nlmsg_seq++;
- EXPECT_NO_ERRNO(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+ ASSERT_NO_ERRNO(LinkReplaceLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)));
// Create exclusive should fail, as we created the address above.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
- req.hdr.nlmsg_seq++;
- EXPECT_THAT(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len),
- PosixErrorIs(EEXIST, ::testing::_));
+ EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EEXIST, ::testing::_));
}
// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request.
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
index 7a0bad4cb..46f749c7c 100644
--- a/test/syscalls/linux/socket_netlink_route_util.cc
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -29,6 +29,8 @@ constexpr uint32_t kSeq = 12345;
// Types of address modifications that may be performed on an interface.
enum class LinkAddrModification {
kAdd,
+ kAddExclusive,
+ kReplace,
kDelete,
};
@@ -40,6 +42,14 @@ PosixError PopulateNlmsghdr(LinkAddrModification modification,
hdr->nlmsg_type = RTM_NEWADDR;
hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
return NoError();
+ case LinkAddrModification::kAddExclusive:
+ hdr->nlmsg_type = RTM_NEWADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_EXCL | NLM_F_ACK;
+ return NoError();
+ case LinkAddrModification::kReplace:
+ hdr->nlmsg_type = RTM_NEWADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
+ return NoError();
case LinkAddrModification::kDelete:
hdr->nlmsg_type = RTM_DELADDR;
hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
@@ -144,6 +154,18 @@ PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
LinkAddrModification::kAdd);
}
+PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kAddExclusive);
+}
+
+PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kReplace);
+}
+
PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen) {
return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
index e5badca70..eaa91ad79 100644
--- a/test/syscalls/linux/socket_netlink_route_util.h
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -39,10 +39,19 @@ PosixErrorOr<std::vector<Link>> DumpLinks();
// Returns the loopback link on the system. ENOENT if not found.
PosixErrorOr<Link> LoopbackLink();
-// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+// LinkAddLocalAddr adds a new IFA_LOCAL address to the interface.
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen);
+// LinkAddExclusiveLocalAddr adds a new IFA_LOCAL address with NLM_F_EXCL flag
+// to the interface.
+PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkReplaceLocalAddr replaces an IFA_LOCAL address on the interface.
+PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface.
PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen);
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index e11792309..a760581b5 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -753,8 +753,7 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
return ret;
}
-PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
- int timeout) {
+PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout) {
fd_set rfd;
struct timeval to = {.tv_sec = timeout, .tv_usec = 0};
FD_ZERO(&rfd);
@@ -767,6 +766,19 @@ PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
return ret;
}
+PosixErrorOr<int> RecvMsgTimeout(int sock, struct msghdr* msg, int timeout) {
+ fd_set rfd;
+ struct timeval to = {.tv_sec = timeout, .tv_usec = 0};
+ FD_ZERO(&rfd);
+ FD_SET(sock, &rfd);
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ ret = RetryEINTR(recvmsg)(sock, msg, MSG_DONTWAIT));
+ return ret;
+}
+
void RecvNoData(int sock) {
char data = 0;
struct iovec iov;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 468bc96e0..5e205339f 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -467,9 +467,12 @@ PosixError FreeAvailablePort(int port);
// SendMsg converts a buffer to an iovec and adds it to msg before sending it.
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size);
-// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock.
-PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
- int timeout);
+// RecvTimeout calls select on sock with timeout and then calls recv on sock.
+PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout);
+
+// RecvMsgTimeout calls select on sock with timeout and then calls recvmsg on
+// sock.
+PosixErrorOr<int> RecvMsgTimeout(int sock, msghdr* msg, int timeout);
// RecvNoData checks that no data is receivable on sock.
void RecvNoData(int sock);
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 1edcb15a7..ad9c4bf37 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -121,6 +121,19 @@ TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&got_linger, &sl, length));
}
+TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index a1d2b9b11..c2369db54 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -26,9 +26,11 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
+#include "test/util/signal_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
namespace gvisor {
namespace testing {
@@ -772,6 +774,59 @@ TEST(SpliceTest, FromPipeToDevZero) {
SyscallSucceedsWithValue(0));
}
+static volatile int signaled = 0;
+void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; }
+
+TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin_NoRandomSave) {
+ // Writes to a pipe that are less than PIPE_BUF must be atomic. This test
+ // creates a pipe with only 128 bytes of capacity (< PIPE_BUF) and checks that
+ // splicing to the pipe does not spin. See b/170743336.
+
+ // Create a file with one page of data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), absl::string_view(buf.data(), buf.size()),
+ TempPath::kDefaultFileMode));
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+
+ // Create a pipe with size 4096, and fill all but 128 bytes of it.
+ int p[2];
+ ASSERT_THAT(pipe(p), SyscallSucceeds());
+ ASSERT_THAT(fcntl(p[1], F_SETPIPE_SZ, kPageSize), SyscallSucceeds());
+ const int kWriteSize = kPageSize - 128;
+ std::vector<char> writeBuf(kWriteSize);
+ RandomizeBuffer(writeBuf.data(), writeBuf.size());
+ ASSERT_THAT(write(p[1], writeBuf.data(), writeBuf.size()),
+ SyscallSucceedsWithValue(kWriteSize));
+
+ // Set up signal handler.
+ struct sigaction sa = {};
+ sa.sa_sigaction = SigUsr1Handler;
+ sa.sa_flags = SA_SIGINFO;
+ const auto cleanup_sigact =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa));
+
+ // Send SIGUSR1 to this thread in 1 second.
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_THREAD_ID;
+ sev.sigev_signo = SIGUSR1;
+ sev.sigev_notify_thread_id = gettid();
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ struct itimerspec its = {};
+ its.it_value = absl::ToTimespec(absl::Seconds(1));
+ DisableSave ds; // Asserting an EINTR.
+ ASSERT_NO_ERRNO(timer.Set(0, its));
+
+ // Now splice the file to the pipe. This should block, but not spin, and
+ // should return EINTR because it is interrupted by the signal.
+ EXPECT_THAT(splice(fd.get(), nullptr, p[1], nullptr, kPageSize, 0),
+ SyscallFailsWithErrno(EINTR));
+
+ // Alarm should have been handled.
+ EXPECT_EQ(signaled, 1);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 92260b1e1..6e7142a42 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -31,6 +31,7 @@
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/save_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -328,7 +329,10 @@ TEST_F(StatTest, LeadingDoubleSlash) {
ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st),
SyscallSucceeds());
EXPECT_EQ(st.st_dev, double_slash_st.st_dev);
- EXPECT_EQ(st.st_ino, double_slash_st.st_ino);
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ if (!IsRunningWithSaveRestore()) {
+ EXPECT_EQ(st.st_ino, double_slash_st.st_ino);
+ }
}
// Test that a rename doesn't change the underlying file.
@@ -346,8 +350,14 @@ TEST_F(StatTest, StatDoesntChangeAfterRename) {
EXPECT_EQ(st_old.st_nlink, st_new.st_nlink);
EXPECT_EQ(st_old.st_dev, st_new.st_dev);
+ // Inode numbers for gofer-accessed files on which no reference is held may
+ // change across save/restore because the information that the gofer client
+ // uses to track file identity (9P QID path) is inconsistent between gofer
+ // processes, which are restarted across save/restore.
+ //
// Overlay filesystems may synthesize directory inode numbers on the fly.
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) {
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) {
EXPECT_EQ(st_old.st_ino, st_new.st_ino);
}
EXPECT_EQ(st_old.st_mode, st_new.st_mode);
@@ -541,6 +551,26 @@ TEST_F(StatTest, LstatELOOPPath) {
ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP));
}
+TEST(SimpleStatTest, DifferentFilesHaveDifferentDeviceInodeNumberPairs) {
+ // TODO(gvisor.dev/issue/1624): This test case fails in VFS1 save/restore
+ // tests because VFS1 gofer inode number assignment restarts after
+ // save/restore, such that the inodes for file1 and file2 (which are
+ // unreferenced and therefore not retained in sentry checkpoints before the
+ // calls to lstat()) are assigned the same inode number.
+ SKIP_IF(IsRunningWithVFS1() && IsRunningWithSaveRestore());
+
+ TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ MaybeSave();
+ struct stat st1 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file1.path()));
+ MaybeSave();
+ struct stat st2 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file2.path()));
+ EXPECT_FALSE(st1.st_dev == st2.st_dev && st1.st_ino == st2.st_ino)
+ << "both files have device number " << st1.st_dev << " and inode number "
+ << st1.st_ino;
+}
+
// Ensure that inode allocation for anonymous devices work correctly across
// save/restore. In particular, inode numbers should be unique across S/R.
TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) {
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 9f522f833..ebd873068 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1725,6 +1725,63 @@ TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) {
ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout));
}
+// Tests that SO_ACCEPTCONN returns non zero value for listening sockets.
+TEST_P(TcpSocketTest, GetSocketAcceptConnListener) {
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(listener_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 1);
+}
+
+// Tests that SO_ACCEPTCONN returns zero value for not listening sockets.
+TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) {
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+
+ ASSERT_THAT(getsockopt(t_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
+ // TODO(b/171345701): Fix the TCP state for listening socket on shutdown.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 1);
+
+ EXPECT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds());
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index cac94d9e1..93a98adb1 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -322,11 +322,6 @@ TEST(IntervalTimerTest, PeriodicGroupDirectedSignal) {
EXPECT_GE(counted_signals.load(), kCycles);
}
-// From Linux's include/uapi/asm-generic/siginfo.h.
-#ifndef sigev_notify_thread_id
-#define sigev_notify_thread_id _sigev_un._tid
-#endif
-
TEST(IntervalTimerTest, PeriodicThreadDirectedSignal) {
constexpr int kSigno = SIGUSR1;
constexpr int kSigvalue = 42;
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 1a7673317..d65275fd3 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -679,6 +679,43 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
SyscallSucceedsWithValue(sizeof(buf)));
}
+TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ // Close the socket to release the port so that we get an ICMP error.
+ ASSERT_THAT(close(bind_.release()), SyscallSucceeds());
+
+ // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an
+ // UDP socket. There is no easy way to ensure that the UDP port is not bound
+ // by another conncurrently running test. *This is potentially flaky*.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[512];
+ EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ constexpr int kTimeout = 1000;
+ // Poll to make sure we get the ICMP error back before issuing more writes.
+ struct pollfd pfd = {sock_.get(), POLLERR, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+
+ // Next write should fail with ECONNREFUSED due to the ICMP error generated in
+ // response to the previous write.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ // The next write should succeed again since the last write call would have
+ // retrieved and cleared the socket error.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), SyscallSucceeds());
+
+ // Poll to make sure we get the ICMP error back before issuing more writes.
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+
+ // Next write should fail with ECONNREFUSED due to the ICMP error generated in
+ // response to the previous write.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
// TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
SKIP_IF(IsRunningWithHostinet());
@@ -838,7 +875,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
// Receive the data. It works because it was sent before the connect.
char received[sizeof(buf)];
EXPECT_THAT(
- RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/),
+ RecvTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
@@ -928,9 +965,8 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
SyscallSucceedsWithValue(1));
// We should get the data even though read has been shutdown.
- EXPECT_THAT(
- RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/),
- IsPosixErrorOkAndHolds(2));
+ EXPECT_THAT(RecvTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(2));
// Because we read less than the entire packet length, since it's a packet
// based socket any subsequent reads should return EWOULDBLOCK.
@@ -1698,8 +1734,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
}
@@ -1714,8 +1750,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
}
}
@@ -1785,8 +1821,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) {
for (int i = 0; i < sent - 1; i++) {
// Receive the data.
std::vector<char> received(buf.size());
- EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0);
}
@@ -1851,6 +1887,22 @@ TEST_P(UdpSocketTest, GetSocketDetachFilter) {
SyscallFailsWithErrno(ENOPROTOOPT));
}
+TEST_P(UdpSocketTest, SendToZeroPort) {
+ char buf[8];
+ struct sockaddr_storage addr = InetLoopbackAddr();
+
+ // Sending to an invalid port should fail.
+ SetPort(&addr, 0);
+ EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+
+ SetPort(&addr, 1234);
+ EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
::testing::Values(AddressFamily::kIpv4,
AddressFamily::kIpv6,
diff --git a/test/util/BUILD b/test/util/BUILD
index 26c2b6a2f..1b028a477 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -155,6 +155,10 @@ cc_library(
],
hdrs = ["save_util.h"],
defines = select_system(),
+ deps = [
+ ":logging",
+ "@com_google_absl//absl/types:optional",
+ ],
)
cc_library(
diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc
index cebf7e0ac..deed0c05b 100644
--- a/test/util/posix_error.cc
+++ b/test/util/posix_error.cc
@@ -87,7 +87,7 @@ bool PosixErrorIsMatcherCommonImpl::MatchAndExplain(
return false;
}
- if (!message_matcher_.Matches(error.error_message())) {
+ if (!message_matcher_.Matches(error.message())) {
return false;
}
diff --git a/test/util/posix_error.h b/test/util/posix_error.h
index ad666bce0..b634a7f78 100644
--- a/test/util/posix_error.h
+++ b/test/util/posix_error.h
@@ -26,11 +26,6 @@
namespace gvisor {
namespace testing {
-class PosixErrorIsMatcherCommonImpl;
-
-template <typename T>
-class PosixErrorOr;
-
class ABSL_MUST_USE_RESULT PosixError {
public:
PosixError() {}
@@ -49,7 +44,8 @@ class ABSL_MUST_USE_RESULT PosixError {
// PosixErrorOr.
const PosixError& error() const { return *this; }
- std::string error_message() const { return msg_; }
+ int errno_value() const { return errno_; }
+ std::string message() const { return msg_; }
// ToString produces a full string representation of this posix error
// including the printable representation of the errno and the error message.
@@ -61,14 +57,8 @@ class ABSL_MUST_USE_RESULT PosixError {
void IgnoreError() const {}
private:
- int errno_value() const { return errno_; }
int errno_ = 0;
std::string msg_;
-
- friend class PosixErrorIsMatcherCommonImpl;
-
- template <typename T>
- friend class PosixErrorOr;
};
template <typename T>
@@ -94,15 +84,12 @@ class ABSL_MUST_USE_RESULT PosixErrorOr {
template <typename U>
PosixErrorOr& operator=(PosixErrorOr<U> other);
- // Return a reference to the error or NoError().
- PosixError error() const;
-
- // Returns this->error().error_message();
- std::string error_message() const;
-
// Returns true if this PosixErrorOr contains some T.
bool ok() const;
+ // Return a copy of the contained PosixError or NoError().
+ PosixError error() const;
+
// Returns a reference to our current value, or CHECK-fails if !this->ok().
const T& ValueOrDie() const&;
T& ValueOrDie() &;
@@ -115,7 +102,6 @@ class ABSL_MUST_USE_RESULT PosixErrorOr {
void IgnoreError() const {}
private:
- int errno_value() const;
absl::variant<T, PosixError> value_;
friend class PosixErrorIsMatcherCommonImpl;
@@ -171,16 +157,6 @@ PosixError PosixErrorOr<T>::error() const {
}
template <typename T>
-int PosixErrorOr<T>::errno_value() const {
- return error().errno_value();
-}
-
-template <typename T>
-std::string PosixErrorOr<T>::error_message() const {
- return error().error_message();
-}
-
-template <typename T>
bool PosixErrorOr<T>::ok() const {
return absl::holds_alternative<T>(value_);
}
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
index 384d626f0..59d47e06e 100644
--- a/test/util/save_util.cc
+++ b/test/util/save_util.cc
@@ -21,35 +21,47 @@
#include <atomic>
#include <cerrno>
-#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST"
+#include "absl/types/optional.h"
namespace gvisor {
namespace testing {
namespace {
-enum class CooperativeSaveMode {
- kUnknown = 0, // cooperative_save_mode is statically-initialized to 0
- kAvailable,
- kNotAvailable,
-};
-
-std::atomic<CooperativeSaveMode> cooperative_save_mode;
-
-bool CooperativeSaveEnabled() {
- auto mode = cooperative_save_mode.load();
- if (mode == CooperativeSaveMode::kUnknown) {
- mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr)
- ? CooperativeSaveMode::kAvailable
- : CooperativeSaveMode::kNotAvailable;
- cooperative_save_mode.store(mode);
+std::atomic<absl::optional<bool>> cooperative_save_present;
+std::atomic<absl::optional<bool>> random_save_present;
+
+bool CooperativeSavePresent() {
+ auto present = cooperative_save_present.load();
+ if (!present.has_value()) {
+ present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != nullptr;
+ cooperative_save_present.store(present);
+ }
+ return present.value();
+}
+
+bool RandomSavePresent() {
+ auto present = random_save_present.load();
+ if (!present.has_value()) {
+ present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr;
+ random_save_present.store(present);
}
- return mode == CooperativeSaveMode::kAvailable;
+ return present.value();
}
std::atomic<int> save_disable;
} // namespace
+bool IsRunningWithSaveRestore() {
+ return CooperativeSavePresent() || RandomSavePresent();
+}
+
+void MaybeSave() {
+ if (CooperativeSavePresent() && save_disable.load() == 0) {
+ internal::DoCooperativeSave();
+ }
+}
+
DisableSave::DisableSave() { save_disable++; }
DisableSave::~DisableSave() { reset(); }
@@ -61,11 +73,5 @@ void DisableSave::reset() {
}
}
-namespace internal {
-bool ShouldSave() {
- return CooperativeSaveEnabled() && (save_disable.load() == 0);
-}
-} // namespace internal
-
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util.h b/test/util/save_util.h
index bddad6120..e7218ae88 100644
--- a/test/util/save_util.h
+++ b/test/util/save_util.h
@@ -17,9 +17,17 @@
namespace gvisor {
namespace testing {
-// Disable save prevents saving while the given function executes.
+
+// Returns true if the environment in which the calling process is executing
+// allows the test to be checkpointed and restored during execution.
+bool IsRunningWithSaveRestore();
+
+// May perform a co-operative save cycle.
//
-// This lasts the duration of the object, unless reset is called.
+// errno is guaranteed to be preserved.
+void MaybeSave();
+
+// Causes MaybeSave to become a no-op until destroyed or reset.
class DisableSave {
public:
DisableSave();
@@ -37,13 +45,13 @@ class DisableSave {
bool reset_ = false;
};
-// May perform a co-operative save cycle.
+namespace internal {
+
+// Causes a co-operative save cycle to occur.
//
// errno is guaranteed to be preserved.
-void MaybeSave();
+void DoCooperativeSave();
-namespace internal {
-bool ShouldSave();
} // namespace internal
} // namespace testing
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
index fbac94912..57431b3ea 100644
--- a/test/util/save_util_linux.cc
+++ b/test/util/save_util_linux.cc
@@ -30,19 +30,19 @@
namespace gvisor {
namespace testing {
-
-void MaybeSave() {
- if (internal::ShouldSave()) {
- int orig_errno = errno;
- // We use it to trigger saving the sentry state
- // when this syscall is called.
- // Notice: this needs to be a valid syscall
- // that is not used in any of the syscall tests.
- syscall(SYS_TRIGGER_SAVE, nullptr, 0);
- errno = orig_errno;
- }
+namespace internal {
+
+void DoCooperativeSave() {
+ int orig_errno = errno;
+ // We use it to trigger saving the sentry state
+ // when this syscall is called.
+ // Notice: this needs to be a valid syscall
+ // that is not used in any of the syscall tests.
+ syscall(SYS_TRIGGER_SAVE, nullptr, 0);
+ errno = orig_errno;
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc
index 931af2c29..7749ded76 100644
--- a/test/util/save_util_other.cc
+++ b/test/util/save_util_other.cc
@@ -14,13 +14,17 @@
#ifndef __linux__
+#include "test/util/logging.h"
+
namespace gvisor {
namespace testing {
+namespace internal {
-void MaybeSave() {
- // Saving is never available in a non-linux environment.
+void DoCooperativeSave() {
+ TEST_CHECK_MSG(false, "DoCooperativeSave not implemented");
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
diff --git a/test/util/signal_util.h b/test/util/signal_util.h
index e7b66aa51..20eebd7e4 100644
--- a/test/util/signal_util.h
+++ b/test/util/signal_util.h
@@ -88,7 +88,7 @@ inline void FixupFault(ucontext_t* ctx) {
#elif __aarch64__
inline void Fault() {
// Zero and dereference x0.
- asm("mov xzr, x0\r\n"
+ asm("mov x0, xzr\r\n"
"str xzr, [x0]\r\n"
:
:
diff --git a/test/util/timer_util.h b/test/util/timer_util.h
index 926e6632f..e389108ef 100644
--- a/test/util/timer_util.h
+++ b/test/util/timer_util.h
@@ -33,6 +33,11 @@
namespace gvisor {
namespace testing {
+// From Linux's include/uapi/asm-generic/siginfo.h.
+#ifndef sigev_notify_thread_id
+#define sigev_notify_thread_id _sigev_un._tid
+#endif
+
// Returns the current time.
absl::Time Now(clockid_t id);
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 88431ce66..3a7de427f 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -26,13 +26,13 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
BUILD_ROOTS := bazel-bin/ bazel-out/
# Bazel container configuration (see below).
-USER ?= gvisor
-HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
+USER := $(shell whoami)
+HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
BUILDER_BASE := gvisor.dev/images/default
BUILDER_IMAGE := gvisor.dev/images/builder
-BUILDER_NAME ?= gvisor-builder-$(HASH)
-DOCKER_NAME ?= gvisor-bazel-$(HASH)
-DOCKER_PRIVILEGED ?= --privileged
+BUILDER_NAME := gvisor-builder-$(HASH)
+DOCKER_NAME := gvisor-bazel-$(HASH)
+DOCKER_PRIVILEGED := --privileged
BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
DOCKER_SOCKET := /var/run/docker.sock
@@ -59,6 +59,25 @@ ifeq (true,$(shell [[ -t 0 ]] && echo true))
FULL_DOCKER_EXEC_OPTIONS += --tty
endif
+# Add basic UID/GID options.
+#
+# Note that USERADD_DOCKER and GROUPADD_DOCKER are both defined as "deferred"
+# variables in Make terminology, that is they will be expanded at time of use
+# and may include other variables, including those defined below.
+#
+# NOTE: we pass -l to useradd below because otherwise you can hit a bug
+# best described here:
+# https://github.com/moby/moby/issues/5419#issuecomment-193876183
+# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out
+# out of disk space.
+ifneq ($(UID),0)
+USERADD_DOCKER += useradd -l --uid $(UID) --non-unique --no-create-home \
+ --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) &&
+endif
+ifneq ($(GID),0)
+GROUPADD_DOCKER += groupadd --gid $(GID) --non-unique $(USER) &&
+endif
+
# Add docker passthrough options.
ifneq ($(DOCKER_PRIVILEGED),)
FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
@@ -91,19 +110,12 @@ ifneq (,$(BAZEL_CONFIG))
OPTIONS += --config=$(BAZEL_CONFIG)
endif
-# NOTE: we pass -l to useradd below because otherwise you can hit a bug
-# best described here:
-# https://github.com/moby/moby/issues/5419#issuecomment-193876183
-# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out
-# out of disk space.
bazel-image: load-default
@if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi
docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \
$(BUILDER_BASE) \
- sh -c "groupadd --gid $(GID) --non-unique $(USER) && \
- $(GROUPADD_DOCKER) \
- useradd -l --uid $(UID) --non-unique --no-create-home \
- --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \
+ sh -c "$(GROUPADD_DOCKER) \
+ $(USERADD_DOCKER) \
if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi"
docker commit $(BUILDER_NAME) $(BUILDER_IMAGE)
@docker rm -f $(BUILDER_NAME)
diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl
index d388346a5..661c9727e 100644
--- a/tools/bazeldefs/go.bzl
+++ b/tools/bazeldefs/go.bzl
@@ -94,10 +94,10 @@ def go_rule(rule, implementation, **kwargs):
toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"]
return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs)
-def go_test_library(target):
- if hasattr(target.attr, "embed") and len(target.attr.embed) > 0:
- return target.attr.embed[0]
- return None
+def go_embed_libraries(target):
+ if hasattr(target.attr, "embed"):
+ return target.attr.embed
+ return []
def go_context(ctx, goos = None, goarch = None, std = False):
"""Extracts a standard Go context struct.
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
index 2b0062a63..1cea9e1c9 100644
--- a/tools/bigquery/BUILD
+++ b/tools/bigquery/BUILD
@@ -9,5 +9,8 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = ["@com_google_cloud_go_bigquery//:go_default_library"],
+ deps = [
+ "@com_google_cloud_go_bigquery//:go_default_library",
+ "@org_golang_google_api//option:go_default_library",
+ ],
)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index 5f1a882de..544af3876 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -25,22 +25,30 @@ import (
"time"
bq "cloud.google.com/go/bigquery"
+ "google.golang.org/api/option"
)
-// Benchmark is the top level structure of recorded benchmark data. BigQuery
+// Suite is the top level structure for a benchmark run. BigQuery
// will infer the schema from this.
+type Suite struct {
+ Name string `bq:"name"`
+ Conditions []*Condition `bq:"conditions"`
+ Benchmarks []*Benchmark `bq:"benchmarks"`
+ Official bool `bq:"official"`
+ Timestamp time.Time `bq:"timestamp"`
+}
+
+// Benchmark represents an individual benchmark in a suite.
type Benchmark struct {
Name string `bq:"name"`
Condition []*Condition `bq:"condition"`
- Timestamp time.Time `bq:"timestamp"`
- Official bool `bq:"official"`
Metric []*Metric `bq:"metric"`
- Metadata *Metadata `bq:"metadata"`
}
-// Condition represents qualifiers for the benchmark. For example:
+// Condition represents qualifiers for the benchmark or suite. For example:
// Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1"
-// and "real_time" parameters as conditions.
+// and "real_time" parameters as conditions. Suite conditions include
+// information such as the CL number and platform name.
type Condition struct {
Name string `bq:"name"`
Value string `bq:"value"`
@@ -53,19 +61,9 @@ type Metric struct {
Sample float64 `bq:"sample"`
}
-// Metadata about this benchmark.
-type Metadata struct {
- CL string `bq:"changelist"`
- IterationID string `bq:"iteration_id"`
- PendingCL string `bq:"pending_cl"`
- Workflow string `bq:"workflow"`
- Platform string `bq:"platform"`
- Gofer string `bq:"gofer"`
-}
-
// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
-func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error {
- client, err := bq.NewClient(ctx, projectID)
+func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opts []option.ClientOption) error {
+ client, err := bq.NewClient(ctx, projectID, opts...)
if err != nil {
return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err)
}
@@ -77,7 +75,7 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) err
}
table := dataset.Table(tableID)
- schema, err := bq.InferSchema(Benchmark{})
+ schema, err := bq.InferSchema(Suite{})
if err != nil {
return fmt.Errorf("failed to infer schema: %v", err)
}
@@ -107,26 +105,35 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
}
// NewBenchmark initializes a new benchmark.
-func NewBenchmark(name string, iters int, official bool) *Benchmark {
+func NewBenchmark(name string, iters int) *Benchmark {
return &Benchmark{
- Name: name,
- Timestamp: time.Now().UTC(),
- Official: official,
- Metric: make([]*Metric, 0),
+ Name: name,
+ Metric: make([]*Metric, 0),
+ }
+}
+
+// NewSuite initializes a new Suite.
+func NewSuite(name string, official bool) *Suite {
+ return &Suite{
+ Name: name,
+ Timestamp: time.Now().UTC(),
+ Benchmarks: make([]*Benchmark, 0),
+ Conditions: make([]*Condition, 0),
+ Official: official,
}
}
// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table.
-func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error {
- client, err := bq.NewClient(ctx, projectID)
+func SendBenchmarks(ctx context.Context, suite *Suite, projectID, datasetID, tableID string, opts []option.ClientOption) error {
+ client, err := bq.NewClient(ctx, projectID, opts...)
if err != nil {
return fmt.Errorf("failed to initialize client on project: %s: %v", projectID, err)
}
defer client.Close()
uploader := client.Dataset(datasetID).Table(tableID).Uploader()
- if err = uploader.Put(ctx, benchmarks); err != nil {
- return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err)
+ if err = uploader.Put(ctx, suite); err != nil {
+ return fmt.Errorf("failed to upload benchmarks %s to project %s, table %s.%s: %v", suite.Name, projectID, datasetID, tableID, err)
}
return nil
diff --git a/tools/defs.bzl b/tools/defs.bzl
index bb291c512..d75e40863 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -86,8 +86,10 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, **
)
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = kwargs.get("srcs", []),
- library = ":" + name + "_nogo_library",
+ deps = [":" + name + "_nogo_library"],
+ tags = ["nogo"],
)
def calculate_sets(srcs):
@@ -218,8 +220,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
if nogo:
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = all_srcs,
- library = ":" + name,
+ deps = [":" + name],
+ tags = ["nogo"],
)
if marshal:
@@ -255,8 +259,10 @@ def go_test(name, nogo = True, **kwargs):
if nogo:
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = kwargs.get("srcs", []),
- library = ":" + name,
+ deps = [":" + name],
+ tags = ["nogo"],
)
def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD
index 0633eaf19..19b7eec4d 100644
--- a/tools/github/nogo/BUILD
+++ b/tools/github/nogo/BUILD
@@ -10,7 +10,7 @@ go_library(
"//tools/github:__subpackages__",
],
deps = [
- "//tools/nogo/util",
+ "//tools/nogo",
"@com_github_google_go_github_v28//github:go_default_library",
],
)
diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go
index b2bc63459..27ab1b8eb 100644
--- a/tools/github/nogo/nogo.go
+++ b/tools/github/nogo/nogo.go
@@ -24,7 +24,7 @@ import (
"time"
"github.com/google/go-github/github"
- "gvisor.dev/gvisor/tools/nogo/util"
+ "gvisor.dev/gvisor/tools/nogo"
)
// FindingsPoster is a simple wrapper around the GitHub api.
@@ -35,7 +35,7 @@ type FindingsPoster struct {
dryRun bool
startTime time.Time
- findings map[util.Finding]struct{}
+ findings map[nogo.Finding]struct{}
client *github.Client
}
@@ -47,7 +47,7 @@ func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun
commit: commit,
dryRun: dryRun,
startTime: time.Now(),
- findings: make(map[util.Finding]struct{}),
+ findings: make(map[nogo.Finding]struct{}),
client: client,
}
}
@@ -63,7 +63,7 @@ func (p *FindingsPoster) Walk(paths []string) error {
if !strings.HasSuffix(filename, ".findings") || info.IsDir() {
return nil
}
- findings, err := util.ExtractFindingsFromFile(filename)
+ findings, err := nogo.ExtractFindingsFromFile(filename)
if err != nil {
return err
}
@@ -86,7 +86,7 @@ func (p *FindingsPoster) Post() error {
if p.dryRun {
for finding, _ := range p.findings {
// Pretty print, so that this is useful for debugging.
- fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message)
+ fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Position.Filename, finding.Position.Line, finding.Message)
}
return nil
}
@@ -115,12 +115,13 @@ func (p *FindingsPoster) Post() error {
}
annotationLevel := "failure" // Always.
for finding, _ := range p.findings {
+ title := string(finding.Category)
opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{
- Path: &finding.Path,
- StartLine: &finding.Line,
- EndLine: &finding.Line,
+ Path: &finding.Position.Filename,
+ StartLine: &finding.Position.Line,
+ EndLine: &finding.Position.Line,
Message: &finding.Message,
- Title: &finding.Category,
+ Title: &title,
AnnotationLevel: &annotationLevel,
})
}
diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go
index a95df0fb6..c4b624f2a 100644
--- a/tools/github/reviver/github.go
+++ b/tools/github/reviver/github.go
@@ -121,13 +121,24 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) {
return true, nil
}
+var issuePrefixes = []string{
+ "gvisor.dev/issue/",
+ "gvisor.dev/issues/",
+}
+
// parseIssueNo parses the issue number out of the issue url.
+//
+// 0 is returned if url does not correspond to an issue.
func parseIssueNo(url string) (int, error) {
- const prefix = "gvisor.dev/issue/"
-
// First check if I can handle the TODO.
- idStr := strings.TrimPrefix(url, prefix)
- if len(url) == len(idStr) {
+ var idStr string
+ for _, p := range issuePrefixes {
+ if str := strings.TrimPrefix(url, p); len(str) < len(url) {
+ idStr = str
+ break
+ }
+ }
+ if len(idStr) == 0 {
return 0, nil
}
diff --git a/tools/github/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go
index a9fb1f9f1..851306c9d 100644
--- a/tools/github/reviver/reviver_test.go
+++ b/tools/github/reviver/reviver_test.go
@@ -33,6 +33,15 @@ func TestProcessLine(t *testing.T) {
},
},
{
+ line: "// TODO(foobar.com/issues/123): comment, bla. blabla.",
+ want: &Todo{
+ Issue: "foobar.com/issues/123",
+ Locations: []Location{
+ {Comment: "comment, bla. blabla."},
+ },
+ },
+ },
+ {
line: "// FIXME(b/123): internal bug",
want: &Todo{
Issue: "b/123",
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index e5c060024..71d036b12 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -14,23 +14,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -xeo pipefail
+set -xeou pipefail
# Discovery the package name from the go.mod file.
-declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2)
-declare -r origpwd=$(pwd)
-declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE")
+declare module origpwd othersrc
+module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2)
+origpwd=$(pwd)
+othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE")
+readonly module origpwd othersrc
+
# Check that gopath has been built.
-declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}"
-if ! [ -d "${gopath_dir}" ]; then
+declare gopath_dir
+gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}"
+readonly gopath_dir
+if ! [[ -d "${gopath_dir}" ]]; then
echo "No gopath directory found; build the :gopath target." >&2
exit 1
fi
# Create a temporary working directory, and ensure that this directory and all
# subdirectories are cleaned up upon exit.
-declare -r tmp_dir=$(mktemp -d)
+declare tmp_dir
+tmp_dir=$(mktemp -d)
+readonly tmp_dir
finish() {
cd # Leave tmp_dir.
rm -rf "${tmp_dir}"
@@ -38,21 +45,27 @@ finish() {
trap finish EXIT
# Record the current working commit.
-declare -r head=$(git describe --always)
+declare head
+head=$(git describe --always)
+readonly head
# We expect to have an existing go branch that we will use as the basis for this
# commit. That branch may be empty, but it must exist. We search for this branch
# using the local branch, the "origin" branch, and other remotes, in order.
git fetch --all
-declare -r go_branch=$( \
+declare go_branch
+go_branch=$( \
git show-ref --hash refs/heads/go || \
git show-ref --hash refs/remotes/origin/go || \
git show-ref --hash go | head -n 1 \
)
+readonly go_branch
# Clone the current repository to the temporary directory, and check out the
# current go_branch directory. We move to the new repository for convenience.
-declare -r repo_orig="$(pwd)"
+declare repo_orig
+repo_orig="$(pwd)"
+readonly repo_orig
declare -r repo_new="${tmp_dir}/repository"
git clone . "${repo_new}"
cd "${repo_new}"
@@ -68,8 +81,8 @@ git checkout -b go "${go_branch}"
#
# N.B. The git behavior changed at some point and the relevant flag was added
# to allow for override, so try the only behavior first then pass the flag.
-git merge --no-commit --strategy ours ${head} || \
- git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
+git merge --no-commit --strategy ours "${head}" || \
+ git merge --allow-unrelated-histories --no-commit --strategy ours "${head}"
# Normalize the permissions on the old branch. Note that they should be
# normalized if constructed by this tool, but we do so before the rsync.
@@ -96,7 +109,7 @@ EOF
# There are a few solitary files that can get left behind due to the way bazel
# constructs the gopath target. Note that we don't find all Go files here
# because they may correspond to unused templates, etc.
-declare -ar binaries=( "runsc" "shim/v1" "shim/v2" )
+declare -ar binaries=( "runsc" "shim/v1" "shim/v2" "webhook" )
for target in "${binaries[@]}"; do
mkdir -p "${target}"
cp "${repo_orig}/${target}"/*.go "${target}/"
@@ -109,7 +122,11 @@ find . -type f -exec chmod 0644 {} \;
find . -type d -exec chmod 0755 {} \;
# Update the current working set and commit.
-git add . && git commit -m "Merge ${head} (automated)"
+# If the current working commit has already been committed to the remote go
+# branch, then we have nothing to commit here. So allow empty commit. This can
+# occur when this script is run parallely (via pull_request and push events)
+# and the push workflow finishes before the pull_request workflow can run this.
+git add . && git commit --allow-empty -m "Merge ${head} (automated)"
# Push the branch back to the original repository.
git remote add orig "${repo_orig}" && git push -f orig go:go
diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go
index 90d3aa1e0..370650e46 100644
--- a/tools/go_generics/imports.go
+++ b/tools/go_generics/imports.go
@@ -48,7 +48,7 @@ func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[st
// Create a new entry in the used map.
path := imports[importName]
if path == "" {
- return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig)
+ return fmt.Errorf("unknown path to package '%s', used in '%s'", importName, orig)
}
m = &importedPackage{
@@ -72,7 +72,7 @@ func convertExpression(s string, imports mapValue, used map[string]*importedPack
// Parse the expression in the input string.
expr, err := parser.ParseExpr(s)
if err != nil {
- return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err)
+ return "", fmt.Errorf("unable to parse \"%s\": %v", s, err)
}
// Go through the AST and update references.
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go
index 7f62b0a2b..df14ae98e 100644
--- a/tools/go_marshal/test/escape/escape.go
+++ b/tools/go_marshal/test/escape/escape.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package escape contains test cases for escape analysis.
package escape
import (
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index d9e9f341b..e7e3ed74a 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -161,7 +161,7 @@ type TestArray [sizeA]int32
// +marshal
type TestArray2 [sizeA * sizeB]int32
-// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// TestArray3 is a newtype on an array with a simple arithmetic expression of
// mixed constants and literals for the array length.
//
// +marshal
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
index 3c6be3339..12b8b597c 100644
--- a/tools/nogo/BUILD
+++ b/tools/nogo/BUILD
@@ -20,20 +20,14 @@ nogo_stdlib(
visibility = ["//visibility:public"],
)
-sh_binary(
- name = "gentest",
- srcs = ["gentest.sh"],
- visibility = ["//visibility:public"],
-)
-
go_library(
name = "nogo",
srcs = [
+ "analyzers.go",
"build.go",
"config.go",
- "matchers.go",
+ "findings.go",
"nogo.go",
- "register.go",
],
nogo = False,
visibility = ["//:sandbox"],
diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go
new file mode 100644
index 000000000..b919bc2f8
--- /dev/null
+++ b/tools/nogo/analyzers.go
@@ -0,0 +1,131 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/gob"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/asmdecl"
+ "golang.org/x/tools/go/analysis/passes/assign"
+ "golang.org/x/tools/go/analysis/passes/atomic"
+ "golang.org/x/tools/go/analysis/passes/bools"
+ "golang.org/x/tools/go/analysis/passes/buildtag"
+ "golang.org/x/tools/go/analysis/passes/cgocall"
+ "golang.org/x/tools/go/analysis/passes/composite"
+ "golang.org/x/tools/go/analysis/passes/copylock"
+ "golang.org/x/tools/go/analysis/passes/errorsas"
+ "golang.org/x/tools/go/analysis/passes/httpresponse"
+ "golang.org/x/tools/go/analysis/passes/loopclosure"
+ "golang.org/x/tools/go/analysis/passes/lostcancel"
+ "golang.org/x/tools/go/analysis/passes/nilfunc"
+ "golang.org/x/tools/go/analysis/passes/nilness"
+ "golang.org/x/tools/go/analysis/passes/printf"
+ "golang.org/x/tools/go/analysis/passes/shadow"
+ "golang.org/x/tools/go/analysis/passes/shift"
+ "golang.org/x/tools/go/analysis/passes/stdmethods"
+ "golang.org/x/tools/go/analysis/passes/stringintconv"
+ "golang.org/x/tools/go/analysis/passes/structtag"
+ "golang.org/x/tools/go/analysis/passes/tests"
+ "golang.org/x/tools/go/analysis/passes/unmarshal"
+ "golang.org/x/tools/go/analysis/passes/unreachable"
+ "golang.org/x/tools/go/analysis/passes/unsafeptr"
+ "golang.org/x/tools/go/analysis/passes/unusedresult"
+ "honnef.co/go/tools/staticcheck"
+ "honnef.co/go/tools/stylecheck"
+
+ "gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/checkunsafe"
+)
+
+// AllAnalyzers is a list of all available analyzers.
+var AllAnalyzers = []*analysis.Analyzer{
+ asmdecl.Analyzer,
+ assign.Analyzer,
+ atomic.Analyzer,
+ bools.Analyzer,
+ buildtag.Analyzer,
+ cgocall.Analyzer,
+ composite.Analyzer,
+ copylock.Analyzer,
+ errorsas.Analyzer,
+ httpresponse.Analyzer,
+ loopclosure.Analyzer,
+ lostcancel.Analyzer,
+ nilfunc.Analyzer,
+ nilness.Analyzer,
+ printf.Analyzer,
+ shift.Analyzer,
+ stdmethods.Analyzer,
+ stringintconv.Analyzer,
+ shadow.Analyzer,
+ structtag.Analyzer,
+ tests.Analyzer,
+ unmarshal.Analyzer,
+ unreachable.Analyzer,
+ unsafeptr.Analyzer,
+ unusedresult.Analyzer,
+ checkescape.Analyzer,
+ checkunsafe.Analyzer,
+}
+
+// EscapeAnalyzers is a list of escape-related analyzers.
+var EscapeAnalyzers = []*analysis.Analyzer{
+ checkescape.EscapeAnalyzer,
+}
+
+func register(all []*analysis.Analyzer) {
+ // Register all fact types.
+ //
+ // N.B. This needs to be done recursively, because there may be
+ // analyzers in the Requires list that do not appear explicitly above.
+ registered := make(map[*analysis.Analyzer]struct{})
+ var registerOne func(*analysis.Analyzer)
+ registerOne = func(a *analysis.Analyzer) {
+ if _, ok := registered[a]; ok {
+ return
+ }
+
+ // Register dependencies.
+ for _, da := range a.Requires {
+ registerOne(da)
+ }
+
+ // Register local facts.
+ for _, f := range a.FactTypes {
+ gob.Register(f)
+ }
+
+ registered[a] = struct{}{} // Done.
+ }
+ for _, a := range all {
+ registerOne(a)
+ }
+}
+
+func init() {
+ // Add all staticcheck analyzers.
+ for _, a := range staticcheck.Analyzers {
+ AllAnalyzers = append(AllAnalyzers, a)
+ }
+ // Add all stylecheck analyzers.
+ for _, a := range stylecheck.Analyzers {
+ AllAnalyzers = append(AllAnalyzers, a)
+ }
+
+ // Register lists.
+ register(AllAnalyzers)
+ register(EscapeAnalyzers)
+}
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
index 55d34760f..d173cff1f 100644
--- a/tools/nogo/build.go
+++ b/tools/nogo/build.go
@@ -20,22 +20,6 @@ import (
"os"
)
-var (
- // internalPrefix is the internal path prefix. Note that this is not
- // special, as paths should be passed relative to the repository root
- // and should not have any special prefix applied.
- internalPrefix = fmt.Sprintf("^")
-
- // internalDefault is applied when no paths are provided.
- internalDefault = fmt.Sprintf("%s/.*", notPath("external"))
-
- // generatedPrefix is a regex for generated files.
- generatedPrefix = "^(.*/)?(bazel-genfiles|bazel-out|bazel-bin)/"
-
- // externalPrefix is external workspace packages.
- externalPrefix = "^external/"
-)
-
// findStdPkg needs to find the bundled standard library packages.
func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) {
if path == "C" {
diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD
index 21ba2c306..e18483a18 100644
--- a/tools/nogo/check/BUILD
+++ b/tools/nogo/check/BUILD
@@ -2,8 +2,6 @@ load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
-# Note that the check binary must be public, since an aspect may be applied
-# across lots of different rules in different repositories.
go_binary(
name = "check",
srcs = ["main.go"],
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 3828edf3a..69bdfe502 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -16,9 +16,99 @@
package main
import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+
"gvisor.dev/gvisor/tools/nogo"
)
+var (
+ packageFile = flag.String("package", "", "package configuration file (in JSON format)")
+ stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
+ findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
+ factsOutput = flag.String("facts", "", "output file for facts (optional)")
+ escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
+)
+
+func loadConfig(file string, config interface{}) interface{} {
+ // Load the configuration.
+ f, err := os.Open(file)
+ if err != nil {
+ log.Fatalf("unable to open configuration %q: %v", file, err)
+ }
+ defer f.Close()
+ dec := json.NewDecoder(f)
+ dec.DisallowUnknownFields()
+ if err := dec.Decode(config); err != nil {
+ log.Fatalf("unable to decode configuration: %v", err)
+ }
+ return config
+}
+
func main() {
- nogo.Main()
+ // Parse all flags.
+ flag.Parse()
+
+ var (
+ findings []nogo.Finding
+ factData []byte
+ err error
+ )
+
+ // Check & load the configuration.
+ if *packageFile != "" && *stdlibFile != "" {
+ log.Fatalf("unable to perform stdlib and package analysis; provide only one!")
+ }
+
+ // Run the configuration.
+ if *stdlibFile != "" {
+ // Perform basic analysis.
+ c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig)
+ findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers)
+
+ } else if *packageFile != "" {
+ // Perform basic analysis.
+ c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
+ findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
+
+ // Do we need to do escape analysis?
+ if *escapesOutput != "" {
+ escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil)
+ if err != nil {
+ log.Fatalf("error performing escape analysis: %v", err)
+ }
+ if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil {
+ log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err)
+ }
+ }
+ } else {
+ log.Fatalf("please provide at least one of package or stdlib!")
+ }
+
+ // Check that analysis was successful.
+ if err != nil {
+ log.Fatalf("error performing analysis: %v", err)
+ }
+
+ // Save facts.
+ if *factsOutput != "" {
+ if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil {
+ log.Fatalf("error saving findings to %q: %v", *factsOutput, err)
+ }
+ }
+
+ // Write all findings.
+ if *findingsOutput != "" {
+ if err := nogo.WriteFindingsToFile(findings, *findingsOutput); err != nil {
+ log.Fatalf("error writing findings to %q: %v", *findingsOutput, err)
+ }
+ } else {
+ for _, finding := range findings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding.String())
+ }
+ }
}
diff --git a/tools/nogo/config.go b/tools/nogo/config.go
index 0853f03cf..2fea5b3e1 100644
--- a/tools/nogo/config.go
+++ b/tools/nogo/config.go
@@ -15,544 +15,247 @@
package nogo
import (
- "golang.org/x/tools/go/analysis"
- "golang.org/x/tools/go/analysis/passes/asmdecl"
- "golang.org/x/tools/go/analysis/passes/assign"
- "golang.org/x/tools/go/analysis/passes/atomic"
- "golang.org/x/tools/go/analysis/passes/bools"
- "golang.org/x/tools/go/analysis/passes/buildtag"
- "golang.org/x/tools/go/analysis/passes/cgocall"
- "golang.org/x/tools/go/analysis/passes/composite"
- "golang.org/x/tools/go/analysis/passes/copylock"
- "golang.org/x/tools/go/analysis/passes/errorsas"
- "golang.org/x/tools/go/analysis/passes/httpresponse"
- "golang.org/x/tools/go/analysis/passes/loopclosure"
- "golang.org/x/tools/go/analysis/passes/lostcancel"
- "golang.org/x/tools/go/analysis/passes/nilfunc"
- "golang.org/x/tools/go/analysis/passes/nilness"
- "golang.org/x/tools/go/analysis/passes/printf"
- "golang.org/x/tools/go/analysis/passes/shadow"
- "golang.org/x/tools/go/analysis/passes/shift"
- "golang.org/x/tools/go/analysis/passes/stdmethods"
- "golang.org/x/tools/go/analysis/passes/stringintconv"
- "golang.org/x/tools/go/analysis/passes/structtag"
- "golang.org/x/tools/go/analysis/passes/tests"
- "golang.org/x/tools/go/analysis/passes/unmarshal"
- "golang.org/x/tools/go/analysis/passes/unreachable"
- "golang.org/x/tools/go/analysis/passes/unsafeptr"
- "golang.org/x/tools/go/analysis/passes/unusedresult"
- "honnef.co/go/tools/staticcheck"
- "honnef.co/go/tools/stylecheck"
-
- "gvisor.dev/gvisor/tools/checkescape"
- "gvisor.dev/gvisor/tools/checkunsafe"
+ "fmt"
+ "regexp"
)
-var analyzerConfig = map[*analysis.Analyzer]matcher{
- // Standard analyzers.
- asmdecl.Analyzer: alwaysMatches(),
- assign.Analyzer: externalExcluded(
- ".*gazelle/walk/walk.go", // False positive.
- ),
- atomic.Analyzer: alwaysMatches(),
- bools.Analyzer: alwaysMatches(),
- buildtag.Analyzer: alwaysMatches(),
- cgocall.Analyzer: alwaysMatches(),
- composite.Analyzer: and(
- disableMatches(), // Disabled for now.
- resultExcluded{
- "Object_",
- "Range{",
- },
- ),
- copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos).
- errorsas.Analyzer: alwaysMatches(),
- httpresponse.Analyzer: alwaysMatches(),
- loopclosure.Analyzer: alwaysMatches(),
- lostcancel.Analyzer: internalMatches(), // Common external issues.
- nilfunc.Analyzer: alwaysMatches(),
- nilness.Analyzer: and(
- internalMatches(), // Common "tautological checks".
- internalExcluded(
- "pkg/sentry/platform/kvm/kvm_test.go", // Intentional.
- "tools/bigquery/bigquery.go", // False positive.
- ),
- ),
- printf.Analyzer: alwaysMatches(),
- shift.Analyzer: alwaysMatches(),
- stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write").
- stringintconv.Analyzer: and(
- internalExcluded(),
- externalExcluded(
- ".*protobuf/.*.go", // Bad conversions.
- ".*flate/huffman_bit_writer.go", // Bad conversion.
+// GroupName is a named group.
+type GroupName string
+
+// AnalyzerName is a named analyzer.
+type AnalyzerName string
+
+// Group represents a named collection of files.
+type Group struct {
+ // Name is the short name for the group.
+ Name GroupName `yaml:"name"`
+
+ // Regex matches all full paths in the group.
+ Regex string `yaml:"regex"`
+ regex *regexp.Regexp `yaml:"-"`
+
+ // Default determines the default group behavior.
+ //
+ // If Default is true, all Analyzers are enabled for this
+ // group. Otherwise, Analyzers must be individually enabled
+ // by specifying a (possible empty) ItemConfig for the group
+ // in the AnalyzerConfig.
+ Default bool `yaml:"default"`
+}
+
+func (g *Group) compile() error {
+ r, err := regexp.Compile(g.Regex)
+ if err != nil {
+ return err
+ }
+ g.regex = r
+ return nil
+}
+
+// ItemConfig is an (Analyzer,Group) configuration.
+type ItemConfig struct {
+ // Exclude are analyzer exclusions.
+ //
+ // Exclude is a list of regular expressions. If the corresponding
+ // Analyzer emits a Finding for which Finding.Position.String()
+ // matches a regular expression in Exclude, the finding will not
+ // be reported.
+ Exclude []string `yaml:"exclude,omitempty"`
+ exclude []*regexp.Regexp `yaml:"-"`
+
+ // Suppress are analyzer suppressions.
+ //
+ // Suppress is a list of regular expressions. If the corresponding
+ // Analyzer emits a Finding for which Finding.Message matches a regular
+ // expression in Suppress, the finding will not be reported.
+ Suppress []string `yaml:"suppress,omitempty"`
+ suppress []*regexp.Regexp `yaml:"-"`
+}
+
+func compileRegexps(ss []string, rs *[]*regexp.Regexp) error {
+ *rs = make([]*regexp.Regexp, 0, len(ss))
+ for _, s := range ss {
+ r, err := regexp.Compile(s)
+ if err != nil {
+ return err
+ }
+ *rs = append(*rs, r)
+ }
+ return nil
+}
+
+func (i *ItemConfig) compile() error {
+ if i == nil {
+ // This may be nil if nothing is included in the
+ // item configuration. That's fine, there's nothing
+ // to compile and nothing to exclude & suppress.
+ return nil
+ }
+ if err := compileRegexps(i.Exclude, &i.exclude); err != nil {
+ return fmt.Errorf("in exclude: %w", err)
+ }
+ if err := compileRegexps(i.Suppress, &i.suppress); err != nil {
+ return fmt.Errorf("in suppress: %w", err)
+ }
+ return nil
+}
+
+func (i *ItemConfig) merge(other *ItemConfig) {
+ i.Exclude = append(i.Exclude, other.Exclude...)
+ i.Suppress = append(i.Suppress, other.Suppress...)
+}
+
+func (i *ItemConfig) shouldReport(fullPos, msg string) bool {
+ if i == nil {
+ // See above.
+ return true
+ }
+ for _, r := range i.exclude {
+ if r.MatchString(fullPos) {
+ return false
+ }
+ }
+ for _, r := range i.suppress {
+ if r.MatchString(msg) {
+ return false
+ }
+ }
+ return true
+}
+
+// AnalyzerConfig is the configuration for a single analyzers.
+//
+// This map is keyed by individual Group names, to allow for different
+// configurations depending on what Group the file belongs to.
+type AnalyzerConfig map[GroupName]*ItemConfig
+
+func (a AnalyzerConfig) compile() error {
+ for name, gc := range a {
+ if err := gc.compile(); err != nil {
+ return fmt.Errorf("invalid group %q: %v", name, err)
+ }
+ }
+ return nil
+}
+
+func (a AnalyzerConfig) merge(other AnalyzerConfig) {
+ // Merge all the groups.
+ for name, gc := range other {
+ old, ok := a[name]
+ if !ok || old == nil {
+ a[name] = gc // Not configured in a.
+ continue
+ }
+ old.merge(gc)
+ }
+}
+
+func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) bool {
+ gc, ok := a[groupConfig.Name]
+ if !ok {
+ return groupConfig.Default
+ }
+
+ // Note that if a section appears for a particular group
+ // for a particular analyzer, then it will now be enabled,
+ // and the group default no longer applies.
+ return gc.shouldReport(fullPos, msg)
+}
+
+// Config is a nogo configuration.
+type Config struct {
+ // Prefixes defines a set of regular expressions that
+ // are standard "prefixes", so that files can be grouped
+ // and specific rules applied to individual groups.
+ Groups []Group `yaml:"groups"`
- // Runtime internal violations.
- ".*reflect/value.go",
- ".*encoding/xml/xml.go",
- ".*runtime/pprof/internal/profile/proto.go",
- ".*fmt/scan.go",
- ".*go/types/conversions.go",
- ".*golang.org/x/net/dns/dnsmessage/message.go",
- ),
- ),
- shadow.Analyzer: disableMatches(), // Disabled for now.
- structtag.Analyzer: internalMatches(), // External not subject to rules.
- tests.Analyzer: alwaysMatches(),
- unmarshal.Analyzer: alwaysMatches(),
- unreachable.Analyzer: internalMatches(),
- unsafeptr.Analyzer: and(
- internalMatches(),
- internalExcluded(
- ".*_test.go", // Exclude tests.
- "pkg/flipcall/.*_unsafe.go", // Special case.
- "pkg/gohacks/gohacks_unsafe.go", // Special case.
- "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case.
- "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case.
- "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case.
- "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case.
- "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case.
- "pkg/sentry/vfs/mount_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case.
- ),
- ),
- unusedresult.Analyzer: alwaysMatches(),
+ // Global is the global analyzer config.
+ Global AnalyzerConfig `yaml:"global"`
- // Internal analyzers: external packages not subject.
- checkescape.Analyzer: internalMatches(),
- checkunsafe.Analyzer: internalMatches(),
+ // Analyzers are individual analyzer configurations. The
+ // key for each analyzer is the name of the analyzer. The
+ // value is either a boolean (enable/disable), or a map to
+ // the groups above.
+ Analyzers map[AnalyzerName]AnalyzerConfig `yaml:"analyzers"`
}
-func init() {
- staticMatcher := and(
- // Only match internal, non-generated files.
- internalMatches(),
- generatedExcluded(),
+// Merge merges two configurations.
+func (c *Config) Merge(other *Config) {
+ // Merge all groups.
+ for _, g := range other.Groups {
+ // Is there a matching group? If yes, we just delete
+ // it. This will preserve the order provided in the
+ // overriding file, even if it differs.
+ for i := 0; i < len(c.Groups); i++ {
+ if g.Name == c.Groups[i].Name {
+ copy(c.Groups[i:], c.Groups[i+1:])
+ c.Groups = c.Groups[:len(c.Groups)-1]
+ break
+ }
+ }
+ c.Groups = append(c.Groups, g)
+ }
- // We use ALL_CAPS for system definitions,
- // which are common enough in the code base
- // that we shouldn't annotate exceptions.
- //
- // Same story for underscores.
- resultExcluded([]string{
- "should not use ALL_CAPS in Go names",
- "should not use underscores in Go names",
- }),
+ // Merge global configurations.
+ c.Global.merge(other.Global)
- // Exclude existing matches.
- internalExcluded(
- "pkg/abi/linux/fuse.go:22",
- "pkg/abi/linux/fuse.go:25",
- "pkg/abi/linux/socket.go:113",
- "pkg/abi/linux/tty.go:73",
- "pkg/bpf/decoder.go:112",
- "pkg/cpuid/cpuid_x86.go:675",
- "pkg/eventchannel/event.go:193",
- "pkg/eventchannel/event.go:27",
- "pkg/eventchannel/event_test.go:22",
- "pkg/eventchannel/rate.go:19",
- "pkg/gohacks/gohacks_unsafe.go:33",
- "pkg/log/json.go:30",
- "pkg/log/log.go:359",
- "pkg/merkletree/merkletree.go:230",
- "pkg/merkletree/merkletree.go:243",
- "pkg/merkletree/merkletree.go:249",
- "pkg/merkletree/merkletree.go:266",
- "pkg/merkletree/merkletree.go:355",
- "pkg/merkletree/merkletree.go:369",
- "pkg/metric/metric_test.go:20",
- "pkg/p9/p9test/client_test.go:687",
- "pkg/p9/transport_test.go:196",
- "pkg/pool/pool.go:15",
- "pkg/refs/refcounter.go:510",
- "pkg/refs/refcounter_test.go:169",
- "pkg/refs_vfs2/refs.go:16",
- "pkg/safemem/block_unsafe.go:89",
- "pkg/seccomp/seccomp.go:82",
- "pkg/segment/test/set_functions.go:15",
- "pkg/sentry/arch/signal.go:166",
- "pkg/sentry/arch/signal.go:171",
- "pkg/sentry/control/pprof.go:196",
- "pkg/sentry/devices/memdev/full.go:58",
- "pkg/sentry/devices/memdev/null.go:59",
- "pkg/sentry/devices/memdev/random.go:68",
- "pkg/sentry/devices/memdev/zero.go:86",
- "pkg/sentry/fdimport/fdimport.go:15",
- "pkg/sentry/fs/attr.go:257",
- "pkg/sentry/fsbridge/fs.go:116",
- "pkg/sentry/fsbridge/vfs.go:124",
- "pkg/sentry/fsbridge/vfs.go:70",
- "pkg/sentry/fs/copy_up.go:365",
- "pkg/sentry/fs/copy_up_test.go:65",
- "pkg/sentry/fs/dev/net_tun.go:161",
- "pkg/sentry/fs/dev/net_tun.go:63",
- "pkg/sentry/fs/dev/null.go:97",
- "pkg/sentry/fs/dirent_cache.go:64",
- "pkg/sentry/fs/file_overlay.go:327",
- "pkg/sentry/fs/file_overlay.go:524",
- "pkg/sentry/fs/filetest/filetest.go:55",
- "pkg/sentry/fs/filetest/filetest.go:60",
- "pkg/sentry/fs/fs.go:77",
- "pkg/sentry/fs/fsutil/file.go:290",
- "pkg/sentry/fs/fsutil/file.go:346",
- "pkg/sentry/fs/fsutil/host_file_mapper.go:105",
- "pkg/sentry/fs/fsutil/inode_cached.go:676",
- "pkg/sentry/fs/fsutil/inode_cached.go:772",
- "pkg/sentry/fs/gofer/attr.go:120",
- "pkg/sentry/fs/gofer/fifo.go:33",
- "pkg/sentry/fs/gofer/inode.go:410",
- "pkg/sentry/fsimpl/devpts/devpts.go:110",
- "pkg/sentry/fsimpl/devpts/devpts.go:246",
- "pkg/sentry/fsimpl/devpts/devpts.go:50",
- "pkg/sentry/fsimpl/devpts/master.go:110",
- "pkg/sentry/fsimpl/devpts/master.go:55",
- "pkg/sentry/fsimpl/devpts/replica.go:113",
- "pkg/sentry/fsimpl/devpts/replica.go:57",
- "pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92",
- "pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44",
- "pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91",
- "pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66",
- "pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53",
- "pkg/sentry/fsimpl/eventfd/eventfd.go:268",
- "pkg/sentry/fsimpl/ext/directory.go:163",
- "pkg/sentry/fsimpl/ext/directory.go:164",
- "pkg/sentry/fsimpl/ext/extent_file.go:142",
- "pkg/sentry/fsimpl/ext/extent_file.go:143",
- "pkg/sentry/fsimpl/ext/ext.go:105",
- "pkg/sentry/fsimpl/ext/filesystem.go:287",
- "pkg/sentry/fsimpl/ext/regular_file.go:153",
- "pkg/sentry/fsimpl/ext/symlink.go:113",
- "pkg/sentry/fsimpl/fuse/connection_control.go:194",
- "pkg/sentry/fsimpl/fuse/dev.go:387",
- "pkg/sentry/fsimpl/fuse/dev_test.go:318",
- "pkg/sentry/fsimpl/fuse/fusefs.go:102",
- "pkg/sentry/fsimpl/fuse/read_write.go:129",
- "pkg/sentry/fsimpl/fuse/request_response.go:71",
- "pkg/sentry/fsimpl/gofer/directory.go:135",
- "pkg/sentry/fsimpl/gofer/filesystem.go:679",
- "pkg/sentry/fsimpl/gofer/gofer.go:1694",
- "pkg/sentry/fsimpl/gofer/gofer.go:276",
- "pkg/sentry/fsimpl/gofer/regular_file.go:81",
- "pkg/sentry/fsimpl/gofer/special_file.go:141",
- "pkg/sentry/fsimpl/host/host.go:184",
- "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50",
- "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90",
- "pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273",
- "pkg/sentry/fsimpl/kernfs/filesystem.go:247",
- "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320",
- "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497",
- "pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52",
- "pkg/sentry/fsimpl/overlay/directory.go:119",
- "pkg/sentry/fsimpl/overlay/filesystem.go:527",
- "pkg/sentry/fsimpl/overlay/non_directory.go:152",
- "pkg/sentry/fsimpl/overlay/overlay.go:115",
- "pkg/sentry/fsimpl/overlay/overlay.go:719",
- "pkg/sentry/fsimpl/pipefs/pipefs.go:74",
- "pkg/sentry/fsimpl/proc/filesystem.go:52",
- "pkg/sentry/fsimpl/proc/filesystem.go:81",
- "pkg/sentry/fsimpl/proc/subtasks.go:126",
- "pkg/sentry/fsimpl/proc/subtasks.go:189",
- "pkg/sentry/fsimpl/proc/task_fds.go:168",
- "pkg/sentry/fsimpl/proc/task_fds.go:228",
- "pkg/sentry/fsimpl/proc/task_fds.go:301",
- "pkg/sentry/fsimpl/proc/task_fds.go:318",
- "pkg/sentry/fsimpl/proc/task_fds.go:67",
- "pkg/sentry/fsimpl/proc/task_files.go:112",
- "pkg/sentry/fsimpl/proc/task_files.go:158",
- "pkg/sentry/fsimpl/proc/task_files.go:259",
- "pkg/sentry/fsimpl/proc/task_files.go:285",
- "pkg/sentry/fsimpl/proc/task_files.go:305",
- "pkg/sentry/fsimpl/proc/task_files.go:384",
- "pkg/sentry/fsimpl/proc/task_files.go:403",
- "pkg/sentry/fsimpl/proc/task_files.go:428",
- "pkg/sentry/fsimpl/proc/task_files.go:691",
- "pkg/sentry/fsimpl/proc/task_files.go:770",
- "pkg/sentry/fsimpl/proc/task_files.go:797",
- "pkg/sentry/fsimpl/proc/task_files.go:828",
- "pkg/sentry/fsimpl/proc/task_files.go:879",
- "pkg/sentry/fsimpl/proc/task_files.go:910",
- "pkg/sentry/fsimpl/proc/task_files.go:961",
- "pkg/sentry/fsimpl/proc/task.go:127",
- "pkg/sentry/fsimpl/proc/task.go:193",
- "pkg/sentry/fsimpl/proc/task_net.go:134",
- "pkg/sentry/fsimpl/proc/task_net.go:475",
- "pkg/sentry/fsimpl/proc/task_net.go:491",
- "pkg/sentry/fsimpl/proc/task_net.go:508",
- "pkg/sentry/fsimpl/proc/task_net.go:665",
- "pkg/sentry/fsimpl/proc/task_net.go:715",
- "pkg/sentry/fsimpl/proc/task_net.go:779",
- "pkg/sentry/fsimpl/proc/tasks_files.go:113",
- "pkg/sentry/fsimpl/proc/tasks_files.go:388",
- "pkg/sentry/fsimpl/proc/tasks.go:232",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:145",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:181",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:239",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:291",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:375",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:124",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:15",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:126",
- "pkg/sentry/fsimpl/sockfs/sockfs.go:36",
- "pkg/sentry/fsimpl/sockfs/sockfs.go:79",
- "pkg/sentry/fsimpl/sys/kcov.go:49",
- "pkg/sentry/fsimpl/sys/kcov.go:99",
- "pkg/sentry/fsimpl/sys/sys.go:118",
- "pkg/sentry/fsimpl/sys/sys.go:56",
- "pkg/sentry/fsimpl/testutil/testutil.go:257",
- "pkg/sentry/fsimpl/testutil/testutil.go:260",
- "pkg/sentry/fsimpl/timerfd/timerfd.go:87",
- "pkg/sentry/fsimpl/tmpfs/directory.go:112",
- "pkg/sentry/fsimpl/tmpfs/filesystem.go:195",
- "pkg/sentry/fsimpl/tmpfs/regular_file.go:226",
- "pkg/sentry/fsimpl/tmpfs/regular_file.go:346",
- "pkg/sentry/fsimpl/tmpfs/tmpfs.go:103",
- "pkg/sentry/fsimpl/tmpfs/tmpfs.go:733",
- "pkg/sentry/fsimpl/verity/filesystem.go:490",
- "pkg/sentry/fsimpl/verity/verity.go:156",
- "pkg/sentry/fsimpl/verity/verity.go:629",
- "pkg/sentry/fsimpl/verity/verity.go:672",
- "pkg/sentry/fs/mount.go:162",
- "pkg/sentry/fs/mount.go:256",
- "pkg/sentry/fs/mount_overlay.go:144",
- "pkg/sentry/fs/mounts.go:432",
- "pkg/sentry/fs/proc/exec_args.go:104",
- "pkg/sentry/fs/proc/exec_args.go:73",
- "pkg/sentry/fs/proc/fds.go:269",
- "pkg/sentry/fs/proc/loadavg.go:33",
- "pkg/sentry/fs/proc/meminfo.go:39",
- "pkg/sentry/fs/proc/mounts.go:193",
- "pkg/sentry/fs/proc/mounts.go:84",
- "pkg/sentry/fs/proc/net.go:125",
- "pkg/sentry/fs/proc/proc.go:146",
- "pkg/sentry/fs/proc/proc.go:204",
- "pkg/sentry/fs/proc/seqfile/seqfile.go:210",
- "pkg/sentry/fs/proc/sys.go:146",
- "pkg/sentry/fs/proc/sys.go:43",
- "pkg/sentry/fs/proc/sys_net.go:113",
- "pkg/sentry/fs/proc/sys_net.go:205",
- "pkg/sentry/fs/proc/sys_net.go:233",
- "pkg/sentry/fs/proc/sys_net.go:307",
- "pkg/sentry/fs/proc/sys_net.go:335",
- "pkg/sentry/fs/proc/sys_net.go:446",
- "pkg/sentry/fs/proc/sys_net.go:456",
- "pkg/sentry/fs/proc/sys_net.go:89",
- "pkg/sentry/fs/proc/task.go:170",
- "pkg/sentry/fs/proc/task.go:322",
- "pkg/sentry/fs/proc/task.go:427",
- "pkg/sentry/fs/proc/task.go:467",
- "pkg/sentry/fs/proc/task.go:500",
- "pkg/sentry/fs/proc/task.go:784",
- "pkg/sentry/fs/proc/task.go:839",
- "pkg/sentry/fs/proc/task.go:920",
- "pkg/sentry/fs/proc/uid_gid_map.go:108",
- "pkg/sentry/fs/proc/uid_gid_map.go:79",
- "pkg/sentry/fs/proc/uptime.go:75",
- "pkg/sentry/fs/ramfs/dir.go:447",
- "pkg/sentry/fs/tmpfs/inode_file.go:436",
- "pkg/sentry/fs/tmpfs/inode_file.go:537",
- "pkg/sentry/fs/tty/dir.go:313",
- "pkg/sentry/fs/tty/master.go:131",
- "pkg/sentry/fs/tty/master.go:91",
- "pkg/sentry/fs/tty/replica.go:116",
- "pkg/sentry/fs/tty/replica.go:88",
- "pkg/sentry/kernel/auth/id_map.go:269",
- "pkg/sentry/kernel/fasync/fasync.go:67",
- "pkg/sentry/kernel/kcov.go:209",
- "pkg/sentry/kernel/kcov.go:223",
- "pkg/sentry/kernel/kernel.go:343",
- "pkg/sentry/kernel/kernel.go:368",
- "pkg/sentry/kernel/pipe/node_test.go:112",
- "pkg/sentry/kernel/pipe/node_test.go:119",
- "pkg/sentry/kernel/pipe/node_test.go:130",
- "pkg/sentry/kernel/pipe/node_test.go:137",
- "pkg/sentry/kernel/pipe/node_test.go:149",
- "pkg/sentry/kernel/pipe/node_test.go:150",
- "pkg/sentry/kernel/pipe/node_test.go:158",
- "pkg/sentry/kernel/pipe/node_test.go:174",
- "pkg/sentry/kernel/pipe/node_test.go:180",
- "pkg/sentry/kernel/pipe/node_test.go:193",
- "pkg/sentry/kernel/pipe/node_test.go:202",
- "pkg/sentry/kernel/pipe/node_test.go:205",
- "pkg/sentry/kernel/pipe/node_test.go:216",
- "pkg/sentry/kernel/pipe/node_test.go:219",
- "pkg/sentry/kernel/pipe/node_test.go:271",
- "pkg/sentry/kernel/pipe/node_test.go:290",
- "pkg/sentry/kernel/pipe/pipe_test.go:93",
- "pkg/sentry/kernel/pipe/reader_writer.go:65",
- "pkg/sentry/kernel/posixtimer.go:157",
- "pkg/sentry/kernel/ptrace.go:218",
- "pkg/sentry/kernel/semaphore/semaphore.go:323",
- "pkg/sentry/kernel/sessions.go:123",
- "pkg/sentry/kernel/sessions.go:508",
- "pkg/sentry/kernel/signal_handlers.go:57",
- "pkg/sentry/kernel/task_context.go:72",
- "pkg/sentry/kernel/task_exit.go:67",
- "pkg/sentry/kernel/task_sched.go:255",
- "pkg/sentry/kernel/task_sched.go:280",
- "pkg/sentry/kernel/task_sched.go:323",
- "pkg/sentry/kernel/task_stop.go:192",
- "pkg/sentry/kernel/thread_group.go:530",
- "pkg/sentry/kernel/timekeeper.go:316",
- "pkg/sentry/kernel/vdso.go:106",
- "pkg/sentry/kernel/vdso.go:118",
- "pkg/sentry/memmap/memmap.go:103",
- "pkg/sentry/memmap/memmap.go:163",
- "pkg/sentry/mm/address_space.go:42",
- "pkg/sentry/mm/address_space.go:42",
- "pkg/sentry/mm/aio_context.go:208",
- "pkg/sentry/mm/aio_context.go:288",
- "pkg/sentry/mm/pma.go:683",
- "pkg/sentry/mm/special_mappable.go:80",
- "pkg/sentry/platform/systrap/subprocess.go:370",
- "pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go:124",
- "pkg/sentry/socket/control/control.go:260",
- "pkg/sentry/socket/control/control.go:94",
- "pkg/sentry/socket/control/control_vfs2.go:37",
- "pkg/sentry/socket/hostinet/stack.go:433",
- "pkg/sentry/socket/hostinet/stack.go:438",
- "pkg/sentry/socket/hostinet/stack.go:444",
- "pkg/sentry/socket/hostinet/stack.go:460",
- "pkg/sentry/socket/netfilter/tcp_matcher.go:74",
- "pkg/sentry/socket/netfilter/udp_matcher.go:71",
- "pkg/sentry/socket/netlink/route/protocol.go:38",
- "pkg/sentry/socket/socket.go:332",
- "pkg/sentry/socket/unix/transport/connectioned.go:394",
- "pkg/sentry/socket/unix/transport/connectionless.go:152",
- "pkg/sentry/socket/unix/transport/unix.go:436",
- "pkg/sentry/socket/unix/transport/unix.go:490",
- "pkg/sentry/socket/unix/transport/unix.go:685",
- "pkg/sentry/socket/unix/transport/unix.go:795",
- "pkg/sentry/syscalls/linux/sys_sem.go:62",
- "pkg/sentry/syscalls/linux/sys_time.go:189",
- "pkg/sentry/usage/cpu.go:42",
- "pkg/sentry/vfs/anonfs.go:302",
- "pkg/sentry/vfs/anonfs.go:99",
- "pkg/sentry/vfs/dentry.go:214",
- "pkg/sentry/vfs/epoll.go:168",
- "pkg/sentry/vfs/epoll.go:314",
- "pkg/sentry/vfs/file_description.go:549",
- "pkg/sentry/vfs/file_description_impl_util.go:304",
- "pkg/sentry/vfs/file_description_impl_util.go:412",
- "pkg/sentry/vfs/filesystem.go:76",
- "pkg/sentry/vfs/lock.go:15",
- "pkg/sentry/vfs/lock.go:47",
- "pkg/sentry/vfs/memxattr/xattr.go:37",
- "pkg/sentry/vfs/mount.go:510",
- "pkg/sentry/vfs/mount.go:667",
- "pkg/sentry/vfs/mount_test.go:106",
- "pkg/sentry/vfs/mount_test.go:160",
- "pkg/sentry/vfs/mount_test.go:215",
- "pkg/sentry/vfs/mount_unsafe.go:153",
- "pkg/sentry/vfs/resolving_path.go:228",
- "pkg/sentry/vfs/vfs.go:897",
- "pkg/shim/runsc/runsc.go:16",
- "pkg/shim/runsc/utils.go:16",
- "pkg/shim/v1/proc/deleted_state.go:16",
- "pkg/shim/v1/proc/exec.go:16",
- "pkg/shim/v1/proc/exec_state.go:16",
- "pkg/shim/v1/proc/init.go:16",
- "pkg/shim/v1/proc/init_state.go:16",
- "pkg/shim/v1/proc/io.go:16",
- "pkg/shim/v1/proc/process.go:16",
- "pkg/shim/v1/proc/types.go:16",
- "pkg/shim/v1/proc/utils.go:16",
- "pkg/shim/v1/shim/api.go:16",
- "pkg/shim/v1/shim/platform.go:16",
- "pkg/shim/v1/shim/service.go:16",
- "pkg/shim/v1/utils/annotations.go:15",
- "pkg/shim/v1/utils/utils.go:15",
- "pkg/shim/v1/utils/volumes.go:15",
- "pkg/shim/v2/api.go:16",
- "pkg/shim/v2/epoll.go:18",
- "pkg/shim/v2/options/options.go:15",
- "pkg/shim/v2/options/options.go:24",
- "pkg/shim/v2/options/options.go:26",
- "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16",
- "pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go", // Generated: exempt all.
- "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22",
- "pkg/shim/v2/service.go:15",
- "pkg/shim/v2/service_linux.go:18",
- "pkg/state/tests/integer_test.go:23",
- "pkg/state/tests/integer_test.go:28",
- "pkg/sync/rwmutex_test.go:105",
- "pkg/syserr/host_linux.go:35",
- "pkg/tcpip/adapters/gonet/gonet_test.go:144",
- "pkg/tcpip/adapters/gonet/gonet_test.go:415",
- "pkg/tcpip/adapters/gonet/gonet_test.go:99",
- "pkg/tcpip/buffer/view.go:238",
- "pkg/tcpip/buffer/view.go:238",
- "pkg/tcpip/buffer/view.go:246",
- "pkg/tcpip/header/tcp.go:151",
- "pkg/tcpip/link/sharedmem/pipe/pipe_test.go:493",
- "pkg/tcpip/stack/iptables.go:293",
- "pkg/tcpip/stack/iptables_types.go:277",
- "pkg/tcpip/stack/stack.go:553",
- "pkg/tcpip/stack/transport_test.go:30",
- "pkg/tcpip/transport/packet/endpoint.go:126",
- "pkg/tcpip/transport/raw/endpoint.go:145",
- "pkg/tcpip/transport/tcp/sack_scoreboard.go:167",
- "pkg/unet/unet_test.go:634",
- "pkg/unet/unet_test.go:662",
- "pkg/unet/unet_test.go:703",
- "pkg/unet/unet_test.go:98",
- "pkg/usermem/addr.go:34",
- "pkg/usermem/usermem.go:171",
- "pkg/usermem/usermem.go:170",
- "runsc/boot/compat.go:22",
- "runsc/boot/compat.go:56",
- "runsc/boot/loader.go:1115",
- "runsc/boot/loader.go:1120",
- "runsc/cmd/checkpoint.go:151",
- "runsc/config/flags.go:32",
- "runsc/container/container.go:641",
- "runsc/container/container.go:988",
- "runsc/specutils/specutils.go:172",
- "runsc/specutils/specutils.go:428",
- "runsc/specutils/specutils.go:436",
- "runsc/specutils/specutils.go:442",
- "runsc/specutils/specutils.go:447",
- "runsc/specutils/specutils.go:454",
- "test/cmd/test_app/fds.go:171",
- "test/iptables/filter_output.go:251",
- "test/packetimpact/testbench/connections.go:77",
- "tools/bigquery/bigquery.go:106",
- "tools/checkescape/test1/test1.go:108",
- "tools/checkescape/test1/test1.go:122",
- "tools/checkescape/test1/test1.go:137",
- "tools/checkescape/test1/test1.go:151",
- "tools/checkescape/test1/test1.go:170",
- "tools/checkescape/test1/test1.go:39",
- "tools/checkescape/test1/test1.go:45",
- "tools/checkescape/test1/test1.go:50",
- "tools/checkescape/test1/test1.go:64",
- "tools/checkescape/test1/test1.go:80",
- "tools/checkescape/test1/test1.go:94",
- "tools/go_generics/imports.go:51",
- "tools/go_generics/imports.go:75",
- "tools/go_marshal/gomarshal/generator.go:177",
- "tools/go_marshal/gomarshal/generator.go:81",
- "tools/go_marshal/gomarshal/generator.go:85",
- "tools/go_marshal/test/escape/escape.go:15",
- "tools/go_marshal/test/test.go:164",
- ),
- )
+ // Merge all analyzer configurations.
+ for name, ac := range other.Analyzers {
+ old, ok := c.Analyzers[name]
+ if !ok {
+ c.Analyzers[name] = ac // No analyzer in original config.
+ continue
+ }
+ old.merge(ac)
+ }
+}
- // Add all staticcheck analyzers; internal only.
- for _, a := range staticcheck.Analyzers {
- analyzerConfig[a] = staticMatcher
+// Compile compiles a configuration to make it useable.
+func (c *Config) Compile() error {
+ for i := 0; i < len(c.Groups); i++ {
+ if err := c.Groups[i].compile(); err != nil {
+ return fmt.Errorf("invalid group %q: %w", c.Groups[i].Name, err)
+ }
}
- // Add all stylecheck analyzers; internal only.
- for _, a := range stylecheck.Analyzers {
- analyzerConfig[a] = staticMatcher
+ if err := c.Global.compile(); err != nil {
+ return fmt.Errorf("invalid global: %w", err)
}
+ for name, ac := range c.Analyzers {
+ if err := ac.compile(); err != nil {
+ return fmt.Errorf("invalid analyzer %q: %w", name, err)
+ }
+ }
+ return nil
}
-var escapesConfig = map[*analysis.Analyzer]matcher{
- // Informational only: include all packages.
- checkescape.EscapeAnalyzer: alwaysMatches(),
+// ShouldReport returns true iff the finding should match the Config.
+func (c *Config) ShouldReport(finding Finding) bool {
+ fullPos := finding.Position.String()
+
+ // Find the matching group.
+ var groupConfig *Group
+ for i := 0; i < len(c.Groups); i++ {
+ if c.Groups[i].regex.MatchString(fullPos) {
+ groupConfig = &c.Groups[i]
+ break
+ }
+ }
+
+ // If there is no group matching this path, then
+ // we default to accept the finding.
+ if groupConfig == nil {
+ return true
+ }
+
+ // Suppress via global rule?
+ if !c.Global.shouldReport(groupConfig, fullPos, finding.Message) {
+ return false
+ }
+
+ // Try the analyzer config.
+ ac, ok := c.Analyzers[finding.Category]
+ if !ok {
+ return groupConfig.Default
+ }
+ return ac.shouldReport(groupConfig, fullPos, finding.Message)
}
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index 543598b52..161ea972e 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -1,6 +1,28 @@
"""Nogo rules."""
-load("//tools/bazeldefs:go.bzl", "go_context", "go_importpath", "go_rule", "go_test_library")
+load("//tools/bazeldefs:go.bzl", "go_context", "go_embed_libraries", "go_importpath", "go_rule")
+
+NogoConfigInfo = provider(
+ "information about a nogo configuration",
+ fields = {
+ "srcs": "the collection of configuration files",
+ },
+)
+
+def _nogo_config_impl(ctx):
+ return [NogoConfigInfo(
+ srcs = ctx.files.srcs,
+ )]
+
+nogo_config = rule(
+ implementation = _nogo_config_impl,
+ attrs = {
+ "srcs": attr.label_list(
+ doc = "a list of yaml files (schema defined by tool/nogo/config.go).",
+ allow_files = True,
+ ),
+ },
+)
NogoTargetInfo = provider(
"information about the Go target",
@@ -20,11 +42,14 @@ nogo_target = go_rule(
rule,
implementation = _nogo_target_impl,
attrs = {
- # goarch is the build architecture. This will normally be provided by a
- # select statement, but this information is propagated to other rules.
- "goarch": attr.string(mandatory = True),
- # goos is similarly the build operating system target.
- "goos": attr.string(mandatory = True),
+ "goarch": attr.string(
+ doc = "the Go build architecture (propagated to other rules).",
+ mandatory = True,
+ ),
+ "goos": attr.string(
+ doc = "the Go OS target (propagated to other rules).",
+ mandatory = True,
+ ),
},
)
@@ -81,7 +106,7 @@ NogoStdlibInfo = provider(
"information for nogo analysis (standard library facts)",
fields = {
"facts": "serialized standard library facts",
- "findings": "package findings (if relevant)",
+ "raw_findings": "raw package findings (if relevant)",
},
)
@@ -90,7 +115,7 @@ def _nogo_stdlib_impl(ctx):
nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo]
go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(ctx.label.name + ".facts")
- findings = ctx.actions.declare_file(ctx.label.name + ".findings")
+ raw_findings = ctx.actions.declare_file(ctx.label.name + ".raw_findings")
config = struct(
Srcs = [f.path for f in go_ctx.stdlib_srcs],
GOOS = go_ctx.goos,
@@ -101,15 +126,15 @@ def _nogo_stdlib_impl(ctx):
ctx.actions.write(config_file, config.to_json())
ctx.actions.run(
inputs = [config_file] + go_ctx.stdlib_srcs,
- outputs = [facts, findings],
+ outputs = [facts, raw_findings],
tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
executable = ctx.files._nogo_check[0],
- mnemonic = "GoStandardLibraryAnalysis",
+ mnemonic = "NogoStandardLibraryAnalysis",
progress_message = "Analyzing Go Standard Library",
arguments = go_ctx.nogo_args + [
"-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path,
"-stdlib=%s" % config_file.path,
- "-findings=%s" % findings.path,
+ "-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
],
)
@@ -117,7 +142,7 @@ def _nogo_stdlib_impl(ctx):
# Return the stdlib facts as output.
return [NogoStdlibInfo(
facts = facts,
- findings = findings,
+ raw_findings = raw_findings,
)]
nogo_stdlib = go_rule(
@@ -148,7 +173,8 @@ NogoInfo = provider(
"information for nogo analysis",
fields = {
"facts": "serialized package facts",
- "findings": "package findings (if relevant)",
+ "raw_findings": "raw package findings (if relevant)",
+ "escapes": "escape-only findings (if relevant)",
"importpath": "package import path",
"binaries": "package binary files",
"srcs": "srcs (for go_test support)",
@@ -174,14 +200,12 @@ def _nogo_aspect_impl(target, ctx):
# If we're using the "library" attribute, then we need to aggregate the
# original library sources and dependencies into this target to perform
# proper type analysis.
- if ctx.rule.kind == "go_test":
- library = go_test_library(ctx.rule)
- if library != None:
- info = library[NogoInfo]
- if hasattr(info, "srcs"):
- srcs = srcs + info.srcs
- if hasattr(info, "deps"):
- deps = deps + info.deps
+ for embed in go_embed_libraries(ctx.rule):
+ info = embed[NogoInfo]
+ if hasattr(info, "srcs"):
+ srcs = srcs + info.srcs
+ if hasattr(info, "deps"):
+ deps = deps + info.deps
# Start with all target files and srcs as input.
inputs = target.files.to_list() + srcs
@@ -214,6 +238,7 @@ def _nogo_aspect_impl(target, ctx):
# Collect all info from shadow dependencies.
fact_map = dict()
import_map = dict()
+ all_raw_findings = []
for dep in deps:
# There will be no file attribute set for all transitive dependencies
# that are not go_library or go_binary rules, such as a proto rules.
@@ -231,6 +256,9 @@ def _nogo_aspect_impl(target, ctx):
import_map[info.importpath] = x_files[0]
fact_map[info.importpath] = info.facts.path
+ # Collect all findings; duplicates are resolved at the end.
+ all_raw_findings.extend(info.raw_findings)
+
# Ensure the above are available as inputs.
inputs.append(info.facts)
inputs += info.binaries
@@ -244,7 +272,7 @@ def _nogo_aspect_impl(target, ctx):
nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo]
go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(target.label.name + ".facts")
- findings = ctx.actions.declare_file(target.label.name + ".findings")
+ raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings")
escapes = ctx.actions.declare_file(target.label.name + ".escapes")
config = struct(
ImportPath = importpath,
@@ -262,39 +290,39 @@ def _nogo_aspect_impl(target, ctx):
inputs.append(config_file)
ctx.actions.run(
inputs = inputs,
- outputs = [facts, findings, escapes],
+ outputs = [facts, raw_findings, escapes],
tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
executable = ctx.files._nogo_check[0],
- mnemonic = "GoStaticAnalysis",
+ mnemonic = "NogoAnalysis",
progress_message = "Analyzing %s" % target.label,
arguments = go_ctx.nogo_args + [
"-binary=%s" % target_objfile.path,
"-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path,
"-package=%s" % config_file.path,
- "-findings=%s" % findings.path,
+ "-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
"-escapes=%s" % escapes.path,
],
)
+ # Flatten all findings from all dependencies.
+ #
+ # This is done because all the filtering must be done at the
+ # top-level nogo_test to dynamically apply a configuration.
+ # This does not actually add any additional work here, but
+ # will simply propagate the full list of files.
+ all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings]
+
# Return the package facts as output.
- return [
- NogoInfo(
- facts = facts,
- findings = findings,
- importpath = importpath,
- binaries = binaries,
- srcs = srcs,
- deps = deps,
- ),
- OutputGroupInfo(
- # Expose all findings (should just be a single file). This can be
- # used for build analysis of the nogo findings.
- nogo_findings = depset([findings]),
- # Expose all escape analysis findings (see above).
- nogo_escapes = depset([escapes]),
- ),
- ]
+ return [NogoInfo(
+ facts = facts,
+ raw_findings = all_raw_findings,
+ escapes = escapes,
+ importpath = importpath,
+ binaries = binaries,
+ srcs = srcs,
+ deps = deps,
+ )]
nogo_aspect = go_rule(
aspect,
@@ -327,41 +355,72 @@ nogo_aspect = go_rule(
def _nogo_test_impl(ctx):
"""Check nogo findings."""
- # Build a runner that checks the facts files.
- findings = [dep[NogoInfo].findings for dep in ctx.attr.deps]
- runner = ctx.actions.declare_file(ctx.label.name)
+ # Ensure there's a single dependency.
+ if len(ctx.attr.deps) != 1:
+ fail("nogo_test requires exactly one dep.")
+ raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings
+ escapes = ctx.attr.deps[0][NogoInfo].escapes
+
+ # Build a step that applies the configuration.
+ config_srcs = ctx.attr.config[NogoConfigInfo].srcs
+ findings = ctx.actions.declare_file(ctx.label.name + ".findings")
ctx.actions.run(
- inputs = findings + ctx.files.srcs,
- outputs = [runner],
- tools = depset(ctx.files._gentest),
- executable = ctx.files._gentest[0],
- mnemonic = "Gentest",
+ inputs = raw_findings + ctx.files.srcs + config_srcs,
+ outputs = [findings],
+ tools = depset(ctx.files._filter),
+ executable = ctx.files._filter[0],
+ mnemonic = "GoStaticAnalysis",
progress_message = "Generating %s" % ctx.label,
- arguments = [runner.path] + [f.path for f in findings],
+ arguments = ["-input=%s" % f.path for f in raw_findings] +
+ ["-config=%s" % f.path for f in config_srcs] +
+ ["-output=%s" % findings.path],
)
+
+ # Build a runner that checks the filtered facts.
+ #
+ # Note that this calls the filter binary without any configuration, so all
+ # findings will be included. But this is expected, since we've already
+ # filtered out everything that should not be included.
+ runner = ctx.actions.declare_file(ctx.label.name)
+ runner_content = [
+ "#!/bin/bash",
+ "exec %s -input=%s" % (ctx.files._filter[0].short_path, findings.short_path),
+ "",
+ ]
+ ctx.actions.write(runner, "\n".join(runner_content), is_executable = True)
+
return [DefaultInfo(
+ # The runner just executes the filter again, on the
+ # newly generated filtered findings. We still need
+ # the filter tool as part of our runfiles, however.
+ runfiles = ctx.runfiles(files = ctx.files._filter + [findings]),
executable = runner,
+ ), OutputGroupInfo(
+ # Propagate the filtered filters, for consumption by
+ # build tooling. Note that the build tooling typically
+ # pays attention to the mnemoic above, so this must be
+ # what is expected by the tooling.
+ nogo_findings = depset([findings]),
+ # Expose all escape analysis findings (see above).
+ nogo_escapes = depset([escapes]),
)]
-_nogo_test = rule(
+nogo_test = rule(
implementation = _nogo_test_impl,
attrs = {
- # deps should have only a single element.
- "deps": attr.label_list(aspects = [nogo_aspect]),
- # srcs exist here only to ensure that this target is
- # directly affected by changes to the source files.
- "srcs": attr.label_list(allow_files = True),
- "_gentest": attr.label(default = "//tools/nogo:gentest"),
+ "config": attr.label(
+ mandatory = True,
+ doc = "A rule of kind nogo_config.",
+ ),
+ "deps": attr.label_list(
+ aspects = [nogo_aspect],
+ doc = "Exactly one Go dependency to be analyzed.",
+ ),
+ "srcs": attr.label_list(
+ allow_files = True,
+ doc = "Relevant src files. This is ignored except to make the nogo_test directly affected by the files.",
+ ),
+ "_filter": attr.label(default = "//tools/nogo/filter:filter"),
},
test = True,
)
-
-def nogo_test(name, srcs, library, **kwargs):
- tags = kwargs.pop("tags", []) + ["nogo"]
- _nogo_test(
- name = name,
- srcs = srcs,
- deps = [library],
- tags = tags,
- **kwargs
- )
diff --git a/tools/nogo/filter/BUILD b/tools/nogo/filter/BUILD
new file mode 100644
index 000000000..e56a783e2
--- /dev/null
+++ b/tools/nogo/filter/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "filter",
+ srcs = ["main.go"],
+ nogo = False,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tools/nogo",
+ "@in_gopkg_yaml_v2//:go_default_library",
+ ],
+)
diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go
new file mode 100644
index 000000000..9cf41b3b0
--- /dev/null
+++ b/tools/nogo/filter/main.go
@@ -0,0 +1,131 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary check is the nogo entrypoint.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "strings"
+
+ yaml "gopkg.in/yaml.v2"
+ "gvisor.dev/gvisor/tools/nogo"
+)
+
+type stringList []string
+
+func (s *stringList) String() string {
+ return strings.Join(*s, ",")
+}
+
+func (s *stringList) Set(value string) error {
+ *s = append(*s, value)
+ return nil
+}
+
+var (
+ inputFiles stringList
+ configFiles stringList
+ outputFile string
+ showConfig bool
+)
+
+func init() {
+ flag.Var(&inputFiles, "input", "findings input files")
+ flag.StringVar(&outputFile, "output", "", "findings output file")
+ flag.Var(&configFiles, "config", "findings configuration files")
+ flag.BoolVar(&showConfig, "show-config", false, "dump configuration only")
+}
+
+func main() {
+ flag.Parse()
+
+ // Load all available findings.
+ var findings []nogo.Finding
+ for _, filename := range inputFiles {
+ inputFindings, err := nogo.ExtractFindingsFromFile(filename)
+ if err != nil {
+ log.Fatalf("unable to extract findings from %s: %v", filename, err)
+ }
+ findings = append(findings, inputFindings...)
+ }
+
+ // Open and merge all configuations.
+ config := &nogo.Config{
+ Global: make(nogo.AnalyzerConfig),
+ Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig),
+ }
+ for _, filename := range configFiles {
+ content, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.Fatalf("unable to read %s: %v", filename, err)
+ }
+ var newConfig nogo.Config // For current file.
+ if err := yaml.Unmarshal(content, &newConfig); err != nil {
+ log.Fatalf("unable to decode %s: %v", filename, err)
+ }
+ config.Merge(&newConfig)
+ if showConfig {
+ bytes, err := yaml.Marshal(&newConfig)
+ if err != nil {
+ log.Fatalf("error marshalling config: %v", err)
+ }
+ mergedBytes, err := yaml.Marshal(config)
+ if err != nil {
+ log.Fatalf("error marshalling config: %v", err)
+ }
+ fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes))
+ fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes))
+ }
+ }
+ if err := config.Compile(); err != nil {
+ log.Fatalf("error compiling config: %v", err)
+ }
+ if showConfig {
+ os.Exit(0)
+ }
+
+ // Filter the findings (and aggregate by group).
+ filteredFindings := make([]nogo.Finding, 0, len(findings))
+ for _, finding := range findings {
+ if ok := config.ShouldReport(finding); ok {
+ filteredFindings = append(filteredFindings, finding)
+ }
+ }
+
+ // Write the output (if required).
+ //
+ // If the outputFile is specified, then we exit here. Otherwise,
+ // we continue to write to stdout and treat like a test.
+ if outputFile != "" {
+ if err := nogo.WriteFindingsToFile(filteredFindings, outputFile); err != nil {
+ log.Fatalf("unable to write findings: %v", err)
+ }
+ return
+ }
+
+ // Treat the run as a test.
+ if len(filteredFindings) == 0 {
+ fmt.Fprintf(os.Stdout, "PASS\n")
+ os.Exit(0)
+ }
+ for _, finding := range filteredFindings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding.String())
+ }
+ os.Exit(1)
+}
diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go
new file mode 100644
index 000000000..5bd850269
--- /dev/null
+++ b/tools/nogo/findings.go
@@ -0,0 +1,63 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/json"
+ "fmt"
+ "go/token"
+ "io/ioutil"
+)
+
+// Finding is a single finding.
+type Finding struct {
+ Category AnalyzerName
+ Position token.Position
+ Message string
+}
+
+// String implements fmt.Stringer.String.
+func (f *Finding) String() string {
+ return fmt.Sprintf("%s: %s: %s", f.Category, f.Position.String(), f.Message)
+}
+
+// WriteFindingsToFile writes findings to a file.
+func WriteFindingsToFile(findings []Finding, filename string) error {
+ content, err := WriteFindingsToBytes(findings)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filename, content, 0644)
+}
+
+// WriteFindingsToBytes serializes findings as bytes.
+func WriteFindingsToBytes(findings []Finding) ([]byte, error) {
+ return json.Marshal(findings)
+}
+
+// ExtractFindingsFromFile loads findings from a file.
+func ExtractFindingsFromFile(filename string) ([]Finding, error) {
+ content, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return ExtractFindingsFromBytes(content)
+}
+
+// ExtractFindingsFromBytes loads findings from bytes.
+func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) {
+ err = json.Unmarshal(content, &findings)
+ return findings, err
+}
diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh
deleted file mode 100755
index 0a762f9f6..000000000
--- a/tools/nogo/gentest.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -euo pipefail
-
-if [[ "$#" -lt 2 ]]; then
- echo "usage: $0 <output> <findings...>"
- exit 2
-fi
-declare violations=0
-declare output=$1
-shift
-
-# Start the script.
-echo "#!/bin/sh" > "${output}"
-
-# Read a list of findings files.
-declare filename
-declare line
-for filename in "$@"; do
- if [[ -z "${filename}" ]]; then
- continue
- fi
- while read -r line; do
- line="${line@Q}"
- violations=$((${violations}+1));
- echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}"
- done < "${filename}"
-done
-
-# Show violations.
-if [[ "${violations}" -eq 0 ]]; then
- echo "echo -e '\\033[0;32mPASS\\033[0;31m\\033[0m'" >> "${output}"
-else
- echo "exit 1" >> "${output}"
-fi
diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go
deleted file mode 100644
index b7b73fa27..000000000
--- a/tools/nogo/matchers.go
+++ /dev/null
@@ -1,172 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package nogo
-
-import (
- "go/token"
- "regexp"
- "strings"
-
- "golang.org/x/tools/go/analysis"
-)
-
-type matcher interface {
- ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool
-}
-
-// pathRegexps filters explicit paths.
-type pathRegexps struct {
- expr []*regexp.Regexp
-
- // include, if true, indicates that paths matching any regexp in expr
- // match.
- //
- // If false, paths matching no regexps in expr match.
- include bool
-}
-
-// buildRegexps builds a list of regular expressions.
-//
-// This will panic on error.
-func buildRegexps(prefix string, args ...string) []*regexp.Regexp {
- result := make([]*regexp.Regexp, 0, len(args))
- for _, arg := range args {
- result = append(result, regexp.MustCompile(prefix+arg))
- }
- return result
-}
-
-// notPath works around the lack of backtracking.
-//
-// It is used to construct a regular expression for non-matching components.
-func notPath(name string) string {
- sb := strings.Builder{}
- sb.WriteString("(")
- for i := range name {
- if i > 0 {
- sb.WriteString("|")
- }
- sb.WriteString(name[:i])
- sb.WriteString("[^")
- sb.WriteByte(name[i])
- sb.WriteString("/][^/]*")
- }
- sb.WriteString(")")
- return sb.String()
-}
-
-// ShouldReport implements matcher.ShouldReport.
-func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
- fullPos := fs.Position(d.Pos).String()
- for _, path := range p.expr {
- if path.MatchString(fullPos) {
- return p.include
- }
- }
- return !p.include
-}
-
-// internalExcluded excludes specific internal paths.
-func internalExcluded(paths ...string) *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(internalPrefix, paths...),
- include: false,
- }
-}
-
-// excludedExcluded excludes specific external paths.
-func externalExcluded(paths ...string) *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(externalPrefix, paths...),
- include: false,
- }
-}
-
-// internalMatches returns a path matcher for internal packages.
-func internalMatches() *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(internalPrefix, internalDefault),
- include: true,
- }
-}
-
-// generatedExcluded excludes all generated code.
-func generatedExcluded() *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(generatedPrefix, ".*"),
- include: false,
- }
-}
-
-// resultExcluded excludes explicit message contents.
-type resultExcluded []string
-
-// ShouldReport implements matcher.ShouldReport.
-func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool {
- for _, str := range r {
- if strings.Contains(d.Message, str) {
- return false
- }
- }
- return true // Not excluded.
-}
-
-// andMatcher is a composite matcher.
-type andMatcher struct {
- all []matcher
-}
-
-// ShouldReport implements matcher.ShouldReport.
-func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
- for _, m := range a.all {
- if !m.ShouldReport(d, fs) {
- return false
- }
- }
- return true
-}
-
-// and is a syntactic convension for andMatcher.
-func and(ms ...matcher) *andMatcher {
- return &andMatcher{
- all: ms,
- }
-}
-
-// anyMatcher matches everything.
-type anyMatcher struct{}
-
-// ShouldReport implements matcher.ShouldReport.
-func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
- return true
-}
-
-// alwaysMatches returns an anyMatcher instance.
-func alwaysMatches() anyMatcher {
- return anyMatcher{}
-}
-
-// neverMatcher will never match.
-type neverMatcher struct{}
-
-// ShouldReport implements matcher.ShouldReport.
-func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
- return false
-}
-
-// disableMatches returns a neverMatcher instance.
-func disableMatches() neverMatcher {
- return neverMatcher{}
-}
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
index e19e3c237..779d4d6d8 100644
--- a/tools/nogo/nogo.go
+++ b/tools/nogo/nogo.go
@@ -21,7 +21,6 @@ package nogo
import (
"encoding/json"
"errors"
- "flag"
"fmt"
"go/ast"
"go/build"
@@ -45,20 +44,20 @@ import (
"gvisor.dev/gvisor/tools/checkescape"
)
-// stdlibConfig is serialized as the configuration.
+// StdlibConfig is serialized as the configuration.
//
// This contains everything required for stdlib analysis.
-type stdlibConfig struct {
+type StdlibConfig struct {
Srcs []string
GOOS string
GOARCH string
Tags []string
}
-// packageConfig is serialized as the configuration.
+// PackageConfig is serialized as the configuration.
//
// This contains everything required for single package analysis.
-type packageConfig struct {
+type PackageConfig struct {
ImportPath string
GoFiles []string
NonGoFiles []string
@@ -84,7 +83,7 @@ type saver func([]byte) error
//
// This is done because all stdlib data is stored together, and we don't want
// to load this data many times over.
-func (c *packageConfig) factLoader() (loader, error) {
+func (c *PackageConfig) factLoader() (loader, error) {
allFacts := make(map[string][]byte)
if c.StdlibFacts != "" {
data, err := ioutil.ReadFile(c.StdlibFacts)
@@ -114,7 +113,7 @@ func (c *packageConfig) factLoader() (loader, error) {
// shouldInclude indicates whether the file should be included.
//
// NOTE: This does only basic parsing of tags.
-func (c *packageConfig) shouldInclude(path string) (bool, error) {
+func (c *PackageConfig) shouldInclude(path string) (bool, error) {
ctx := build.Default
ctx.GOOS = c.GOOS
ctx.GOARCH = c.GOARCH
@@ -128,7 +127,7 @@ func (c *packageConfig) shouldInclude(path string) (bool, error) {
// files, and the facts. Note that this importer implementation will always
// pass when a given package is not available.
type importer struct {
- *packageConfig
+ *PackageConfig
fset *token.FileSet
cache map[string]*types.Package
lastErr error
@@ -185,14 +184,14 @@ func (i *importer) Import(path string) (*types.Package, error) {
// ErrSkip indicates the package should be skipped.
var ErrSkip = errors.New("skipped")
-// checkStdlib checks the standard library.
+// CheckStdlib checks the standard library.
//
// This constructs a synthetic package configuration for each library in the
-// standard library sources, and call checkPackage repeatedly.
+// standard library sources, and call CheckPackage repeatedly.
//
// Note that not all parts of the source are expected to build. We skip obvious
// test files, and cmd files, which should not be dependencies.
-func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) {
+func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings []Finding, facts []byte, err error) {
if len(config.Srcs) == 0 {
return nil, nil, nil
}
@@ -225,7 +224,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
// Aggregate all files by directory.
- packages := make(map[string]*packageConfig)
+ packages := make(map[string]*PackageConfig)
for _, file := range config.Srcs {
if !strings.HasPrefix(file, rootSrcPrefix) {
// Superflouous file.
@@ -243,7 +242,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
c, ok := packages[pkg]
if !ok {
- c = &packageConfig{
+ c = &PackageConfig{
ImportPath: pkg,
GOOS: config.GOOS,
GOARCH: config.GOARCH,
@@ -262,7 +261,6 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
// Closure to check a single package.
- allFindings := make([]string, 0)
stdlibFacts := make(map[string][]byte)
stdlibErrs := make(map[string]error)
var checkOne func(pkg string) error // Recursive.
@@ -301,7 +299,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}()
// Run the analysis.
- findings, factData, err := checkPackage(config, ac, checkOne)
+ findings, factData, err := CheckPackage(config, analyzers, checkOne)
if err != nil {
// If we can't analyze a package from the standard library,
// then we skip it. It will simply not have any findings.
@@ -344,7 +342,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
return allFindings, factData, nil
}
-// checkPackage runs all analyzers.
+// CheckPackage runs all given analyzers.
//
// The implementation was adapted from [1], which was in turn adpated from [2].
// This returns a list of matching analysis issues, or an error if the analysis
@@ -352,9 +350,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
//
// [1] bazelbuid/rules_go/tools/builders/nogo_main.go
// [2] golang.org/x/tools/go/checker/internal/checker
-func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) {
+func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importCallback func(string) error) (findings []Finding, factData []byte, err error) {
imp := &importer{
- packageConfig: config,
+ PackageConfig: config,
fset: token.NewFileSet(),
cache: make(map[string]*types.Package),
callback: importCallback,
@@ -406,7 +404,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
// Register fact types and establish dependencies between analyzers.
// The visit closure will execute recursively, and populate results
// will all required analysis results.
- diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic)
results := make(map[*analysis.Analyzer]interface{})
var visit func(*analysis.Analyzer) error // For recursion.
visit = func(a *analysis.Analyzer) error {
@@ -421,27 +418,25 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
}
- // Prepare the matcher.
- m := ac[a]
- report := func(d analysis.Diagnostic) {
- if m.ShouldReport(d, imp.fset) {
- diagnostics[a] = append(diagnostics[a], d)
- }
- }
-
// Run the analysis.
factFilter := make(map[reflect.Type]bool)
for _, f := range a.FactTypes {
factFilter[reflect.TypeOf(f)] = true
}
p := &analysis.Pass{
- Analyzer: a,
- Fset: imp.fset,
- Files: syntax,
- Pkg: types,
- TypesInfo: typesInfo,
- ResultOf: results, // All results.
- Report: report,
+ Analyzer: a,
+ Fset: imp.fset,
+ Files: syntax,
+ Pkg: types,
+ TypesInfo: typesInfo,
+ ResultOf: results, // All results.
+ Report: func(d analysis.Diagnostic) {
+ findings = append(findings, Finding{
+ Category: AnalyzerName(a.Name),
+ Position: imp.fset.Position(d.Pos),
+ Message: d.Message,
+ })
+ },
ImportPackageFact: facts.ImportPackageFact,
ExportPackageFact: facts.ExportPackageFact,
ImportObjectFact: facts.ImportObjectFact,
@@ -464,7 +459,7 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
// Visit all analyzers recursively.
- for a, _ := range ac {
+ for _, a := range analyzers {
if imp.lastErr == ErrSkip {
continue // No local analysis.
}
@@ -473,114 +468,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
}
- // Convert all diagnostics to strings.
- findings := make([]string, 0, len(diagnostics))
- for a, ds := range diagnostics {
- for _, d := range ds {
- // Include the anlyzer name for debugability and configuration.
- findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message))
- }
- }
-
// Return all findings.
- factData := facts.Encode()
- return findings, factData, nil
-}
-
-var (
- packageFile = flag.String("package", "", "package configuration file (in JSON format)")
- stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
- findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
- factsOutput = flag.String("facts", "", "output file for facts (optional)")
- escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
-)
-
-func loadConfig(file string, config interface{}) interface{} {
- // Load the configuration.
- f, err := os.Open(file)
- if err != nil {
- log.Fatalf("unable to open configuration %q: %v", file, err)
- }
- defer f.Close()
- dec := json.NewDecoder(f)
- dec.DisallowUnknownFields()
- if err := dec.Decode(config); err != nil {
- log.Fatalf("unable to decode configuration: %v", err)
- }
- return config
-}
-
-// Main is the entrypoint; it should be called directly from main.
-//
-// N.B. This package registers it's own flags.
-func Main() {
- // Parse all flags.
- flag.Parse()
-
- var (
- findings []string
- factData []byte
- err error
- )
-
- // Check the configuration.
- if *packageFile != "" && *stdlibFile != "" {
- log.Fatalf("unable to perform stdlib and package analysis; provide only one!")
- } else if *stdlibFile != "" {
- // Perform basic analysis.
- c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig)
- findings, factData, err = checkStdlib(c, analyzerConfig)
- } else if *packageFile != "" {
- // Perform basic analysis.
- c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig)
- findings, factData, err = checkPackage(c, analyzerConfig, nil)
- // Do we need to do escape analysis?
- if *escapesOutput != "" {
- f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
- if err != nil {
- log.Fatalf("unable to open output %q: %v", *escapesOutput, err)
- }
- defer f.Close()
- escapes, _, err := checkPackage(c, escapesConfig, nil)
- if err != nil {
- log.Fatalf("error performing escape analysis: %v", err)
- }
- for _, escape := range escapes {
- fmt.Fprintf(f, "%s\n", escape)
- }
- }
- } else {
- log.Fatalf("please provide at least one of package or stdlib!")
- }
-
- // Save facts.
- if *factsOutput != "" {
- if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil {
- log.Fatalf("error saving findings to %q: %v", *factsOutput, err)
- }
- }
-
- // Open the output file.
- var w io.Writer = os.Stdout
- if *findingsOutput != "" {
- f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
- if err != nil {
- log.Fatalf("unable to open output %q: %v", *findingsOutput, err)
- }
- defer f.Close()
- w = f
- }
-
- // Handle findings & errors.
- if err != nil {
- log.Fatalf("error checking package: %v", err)
- }
- if len(findings) == 0 {
- return
- }
-
- // Print findings.
- for _, finding := range findings {
- fmt.Fprintf(w, "%s\n", finding)
- }
+ return findings, facts.Encode(), nil
}
diff --git a/tools/nogo/register.go b/tools/nogo/register.go
deleted file mode 100644
index 34b173937..000000000
--- a/tools/nogo/register.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package nogo
-
-import (
- "encoding/gob"
- "log"
-
- "golang.org/x/tools/go/analysis"
-)
-
-// analyzers returns all configured analyzers.
-func analyzers() (all []*analysis.Analyzer) {
- for a, _ := range analyzerConfig {
- all = append(all, a)
- }
- for a, _ := range escapesConfig {
- all = append(all, a)
- }
- return all
-}
-
-func init() {
- // Validate basic configuration.
- if err := analysis.Validate(analyzers()); err != nil {
- log.Fatalf("unable to validate analyzer: %v", err)
- }
-
- // Register all fact types.
- //
- // N.B. This needs to be done recursively, because there may be
- // analyzers in the Requires list that do not appear explicitly above.
- registered := make(map[*analysis.Analyzer]struct{})
- var register func(*analysis.Analyzer)
- register = func(a *analysis.Analyzer) {
- if _, ok := registered[a]; ok {
- return
- }
-
- // Regsiter dependencies.
- for _, da := range a.Requires {
- register(da)
- }
-
- // Register local facts.
- for _, f := range a.FactTypes {
- gob.Register(f)
- }
-
- registered[a] = struct{}{} // Done.
- }
- for _, a := range analyzers() {
- register(a)
- }
-}
diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD
deleted file mode 100644
index 7ab340b51..000000000
--- a/tools/nogo/util/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "util",
- srcs = ["util.go"],
- visibility = ["//visibility:public"],
-)
diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go
deleted file mode 100644
index 919fec799..000000000
--- a/tools/nogo/util/util.go
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package util contains nogo-related utilities.
-package util
-
-import (
- "fmt"
- "io/ioutil"
- "regexp"
- "strconv"
- "strings"
-)
-
-// findingRegexp is used to parse findings.
-var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`)
-
-const (
- categoryIndex = 1
- fullPathAndLineIndex = 2
- fullPathIndex = 3
- lineIndex = 4
- messageIndex = 7
-)
-
-// Finding is a single finding.
-type Finding struct {
- Category string
- Path string
- Line int
- Message string
-}
-
-// ExtractFindingsFromFile loads findings from a file.
-func ExtractFindingsFromFile(filename string) ([]Finding, error) {
- content, err := ioutil.ReadFile(filename)
- if err != nil {
- return nil, err
- }
- return ExtractFindingsFromBytes(content)
-}
-
-// ExtractFindingsFromBytes loads findings from bytes.
-func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) {
- lines := strings.Split(string(content), "\n")
- for _, singleLine := range lines {
- // Skip blank lines.
- singleLine = strings.TrimSpace(singleLine)
- if singleLine == "" {
- continue
- }
- m := findingRegexp.FindStringSubmatch(singleLine)
- if m == nil {
- // We shouldn't see findings like this.
- return findings, fmt.Errorf("poorly formated line: %v", singleLine)
- }
- if m[fullPathAndLineIndex] == "-" {
- continue // No source file available.
- }
- // Cleanup the message.
- message := m[messageIndex]
- message = strings.Replace(message, " → ", "\n → ", -1)
- message = strings.Replace(message, " or ", "\n or ", -1)
- // Construct a new annotation.
- lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32)
- findings = append(findings, Finding{
- Category: m[categoryIndex],
- Path: m[fullPathIndex],
- Line: int(lineNumber),
- Message: message,
- })
- }
- return findings, nil
-}
diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD
index 7d9c9a3fb..6932bba9a 100644
--- a/tools/parsers/BUILD
+++ b/tools/parsers/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = ["go_parser_test.go"],
library = ":parsers",
+ nogo = False,
deps = [
"//tools/bigquery",
"@com_github_google_go_cmp//cmp:go_default_library",
@@ -19,9 +20,26 @@ go_library(
srcs = [
"go_parser.go",
],
+ nogo = False,
visibility = ["//:sandbox"],
deps = [
"//test/benchmarks/tools",
"//tools/bigquery",
],
)
+
+go_binary(
+ name = "parser",
+ testonly = 1,
+ srcs = [
+ "parser_main.go",
+ "version.go",
+ ],
+ nogo = False,
+ x_defs = {"main.version": "{STABLE_VERSION}"},
+ deps = [
+ ":parsers",
+ "//runsc/flag",
+ "//tools/bigquery",
+ ],
+)
diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go
index 2cf74c883..57e538149 100644
--- a/tools/parsers/go_parser.go
+++ b/tools/parsers/go_parser.go
@@ -27,20 +27,21 @@ import (
"gvisor.dev/gvisor/tools/bigquery"
)
-// parseOutput expects golang benchmark output returns a Benchmark struct formatted for BigQuery.
-func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*bigquery.Benchmark, error) {
- var benchmarks []*bigquery.Benchmark
+// ParseOutput expects golang benchmark output and returns a struct formatted
+// for BigQuery.
+func ParseOutput(output string, name string, official bool) (*bigquery.Suite, error) {
+ suite := bigquery.NewSuite(name, official)
lines := strings.Split(output, "\n")
for _, line := range lines {
- bm, err := parseLine(line, metadata, official)
+ bm, err := parseLine(line)
if err != nil {
return nil, fmt.Errorf("failed to parse line '%s': %v", line, err)
}
if bm != nil {
- benchmarks = append(benchmarks, bm)
+ suite.Benchmarks = append(suite.Benchmarks, bm)
}
}
- return benchmarks, nil
+ return suite, nil
}
// parseLine handles parsing a benchmark line into a bigquery.Benchmark.
@@ -58,9 +59,8 @@ func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*
// {Name: ns/op, Unit: ns/op, Sample: 1397875880}
// {Name: requests_per_second, Unit: QPS, Sample: 140 }
// }
-// Metadata: metadata
//}
-func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigquery.Benchmark, error) {
+func parseLine(line string) (*bigquery.Benchmark, error) {
fields := strings.Fields(line)
// Check if this line is a Benchmark line. Otherwise ignore the line.
@@ -78,8 +78,7 @@ func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigque
return nil, fmt.Errorf("parse name/params: %v", err)
}
- bm := bigquery.NewBenchmark(name, iters, official)
- bm.Metadata = metadata
+ bm := bigquery.NewBenchmark(name, iters)
for _, p := range params {
bm.AddCondition(p.Name, p.Value)
}
diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go
index 36996b7c8..f0737d46b 100644
--- a/tools/parsers/go_parser_test.go
+++ b/tools/parsers/go_parser_test.go
@@ -94,13 +94,11 @@ func TestParseLine(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- got, err := parseLine(tc.data, nil, false)
+ got, err := parseLine(tc.data)
if err != nil {
t.Fatalf("parseLine failed with: %v", err)
}
- tc.want.Timestamp = got.Timestamp
-
if !cmp.Equal(tc.want, got, nil) {
for _, c := range got.Condition {
t.Logf("Cond: %+v", c)
@@ -150,14 +148,14 @@ BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 46
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- bms, err := parseOutput(tc.data, nil, false)
+ suite, err := ParseOutput(tc.data, "", false)
if err != nil {
t.Fatalf("parseOutput failed: %v", err)
- } else if len(bms) != tc.numBenchmarks {
- t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(bms), bms)
+ } else if len(suite.Benchmarks) != tc.numBenchmarks {
+ t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(suite.Benchmarks), suite.Benchmarks)
}
- for _, bm := range bms {
+ for _, bm := range suite.Benchmarks {
if len(bm.Metric) != tc.numMetrics {
t.Fatalf("NumMetrics failed want: %d got: %d %+v", tc.numMetrics, len(bm.Metric), bm.Metric)
}
diff --git a/tools/parsers/parser_main.go b/tools/parsers/parser_main.go
new file mode 100644
index 000000000..7cce69e03
--- /dev/null
+++ b/tools/parsers/parser_main.go
@@ -0,0 +1,135 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary parser parses Benchmark data from golang benchmarks,
+// puts it into a Schema for BigQuery, and sends it to BigQuery.
+// parser will also initialize a table with the Benchmarks BigQuery schema.
+package main
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "gvisor.dev/gvisor/runsc/flag"
+ bq "gvisor.dev/gvisor/tools/bigquery"
+ "gvisor.dev/gvisor/tools/parsers"
+)
+
+const (
+ initString = "init"
+ initDescription = "initializes a new table with benchmarks schema"
+ parseString = "parse"
+ parseDescription = "parses given benchmarks file and sends it to BigQuery table."
+)
+
+var (
+ // The init command will create a new dataset/table in the given project and initialize
+ // the table with the schema in //tools/bigquery/bigquery.go. If the table/dataset exists
+ // or has been initialized, init has no effect and successfully returns.
+ initCmd = flag.NewFlagSet(initString, flag.ContinueOnError)
+ initProject = initCmd.String("project", "", "GCP project to send benchmarks.")
+ initDataset = initCmd.String("dataset", "", "dataset to send benchmarks data.")
+ initTable = initCmd.String("table", "", "table to send benchmarks data.")
+
+ // The parse command parses benchmark data in `file` and sends it to the
+ // requested table.
+ parseCmd = flag.NewFlagSet(parseString, flag.ContinueOnError)
+ file = parseCmd.String("file", "", "file to parse for benchmarks")
+ name = parseCmd.String("suite_name", "", "name of the benchmark suite")
+ parseProject = parseCmd.String("project", "", "GCP project to send benchmarks.")
+ parseDataset = parseCmd.String("dataset", "", "dataset to send benchmarks data.")
+ parseTable = parseCmd.String("table", "", "table to send benchmarks data.")
+ official = parseCmd.Bool("official", false, "mark input data as official.")
+ runtime = parseCmd.String("runtime", "", "runtime used to run the benchmark")
+)
+
+// initBenchmarks initializes a dataset/table in a BigQuery project.
+func initBenchmarks(ctx context.Context) error {
+ return bq.InitBigQuery(ctx, *initProject, *initDataset, *initTable, nil)
+}
+
+// parseBenchmarks parses the given file into the BigQuery schema,
+// adds some custom data for the commit, and sends the data to BigQuery.
+func parseBenchmarks(ctx context.Context) error {
+ data, err := ioutil.ReadFile(*file)
+ if err != nil {
+ return fmt.Errorf("failed to read file %s: %v", *file, err)
+ }
+ suite, err := parsers.ParseOutput(string(data), *name, *official)
+ if err != nil {
+ return fmt.Errorf("failed parse data: %v", err)
+ }
+ if len(suite.Benchmarks) < 1 {
+ fmt.Fprintf(os.Stderr, "Failed to find benchmarks for file: %s", *file)
+ return nil
+ }
+
+ extraConditions := []*bq.Condition{
+ {
+ Name: "runtime",
+ Value: *runtime,
+ },
+ {
+ Name: "version",
+ Value: version,
+ },
+ }
+
+ suite.Official = *official
+ suite.Conditions = append(suite.Conditions, extraConditions...)
+ return bq.SendBenchmarks(ctx, suite, *parseProject, *parseDataset, *parseTable, nil)
+}
+
+func main() {
+ ctx := context.Background()
+ switch {
+ // the "init" command
+ case len(os.Args) >= 2 && os.Args[1] == initString:
+ if err := initCmd.Parse(os.Args[2:]); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err)
+ os.Exit(1)
+ }
+ if err := initBenchmarks(ctx); err != nil {
+ failure := "failed to initialize project: %s dataset: %s table: %s: %v\n"
+ fmt.Fprintf(os.Stderr, failure, *parseProject, *parseDataset, *parseTable, err)
+ os.Exit(1)
+ }
+ // the "parse" command.
+ case len(os.Args) >= 2 && os.Args[1] == parseString:
+ if err := parseCmd.Parse(os.Args[2:]); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err)
+ os.Exit(1)
+ }
+ if err := parseBenchmarks(ctx); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v\n", err)
+ os.Exit(1)
+ }
+ default:
+ printUsage()
+ os.Exit(1)
+ }
+}
+
+// printUsage prints the top level usage string.
+func printUsage() {
+ usage := `Usage: parser <command> <flags> ...
+
+Available commands:
+ %s %s
+ %s %s
+`
+ fmt.Fprintf(os.Stderr, usage, initCmd.Name(), initDescription, parseCmd.Name(), parseDescription)
+}
diff --git a/tools/parsers/version.go b/tools/parsers/version.go
new file mode 100644
index 000000000..ab9194b9d
--- /dev/null
+++ b/tools/parsers/version.go
@@ -0,0 +1,18 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+// version is set during linking.
+var version = "VERSION_MISSING"
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index b0bab74b4..50378065e 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -43,7 +43,7 @@ fi
closest_commit() {
while read line; do
- if [[ "$line" =~ "commit " ]]; then
+ if [[ "$line" =~ ^"commit " ]]; then
current_commit="${line#commit }"
continue
elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then
@@ -57,7 +57,9 @@ closest_commit() {
# Is the passed identifier a sha commit?
if ! git show "${target_commit}" &> /dev/null; then
# Extract the commit given a piper ID.
- declare -r commit="$(git log | closest_commit "${target_commit}")"
+ commit="$(set +o pipefail; \
+ git log --first-parent | closest_commit "${target_commit}")"
+ declare -r commit
else
declare -r commit="${target_commit}"
fi
diff --git a/webhook/BUILD b/webhook/BUILD
new file mode 100644
index 000000000..33c585504
--- /dev/null
+++ b/webhook/BUILD
@@ -0,0 +1,28 @@
+load("//images:defs.bzl", "docker_image")
+load("//tools:defs.bzl", "go_binary", "pkg_tar")
+
+package(licenses = ["notice"])
+
+docker_image(
+ name = "webhook_image",
+ data = ":files",
+ statements = ['ENTRYPOINT ["/webhook"]'],
+)
+
+# files is the full file system of the webhook container. It is simply:
+# /
+# └─ webhook
+pkg_tar(
+ name = "files",
+ srcs = [":webhook"],
+ extension = "tgz",
+ strip_prefix = "/third_party/gvisor/webhook",
+)
+
+go_binary(
+ name = "webhook",
+ srcs = ["main.go"],
+ pure = "on",
+ static = "on",
+ deps = ["//webhook/pkg/cli"],
+)
diff --git a/webhook/main.go b/webhook/main.go
new file mode 100644
index 000000000..220016543
--- /dev/null
+++ b/webhook/main.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary main serves a mutating Kubernetes webhook.
+package main
+
+import (
+ "gvisor.dev/gvisor/webhook/pkg/cli"
+)
+
+func main() {
+ cli.Main()
+}
diff --git a/webhook/pkg/cli/BUILD b/webhook/pkg/cli/BUILD
new file mode 100644
index 000000000..ac093c556
--- /dev/null
+++ b/webhook/pkg/cli/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cli",
+ srcs = ["cli.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//webhook/pkg/injector",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/util/net:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ "@io_k8s_client_go//rest:go_default_library",
+ ],
+)
diff --git a/webhook/pkg/cli/cli.go b/webhook/pkg/cli/cli.go
new file mode 100644
index 000000000..a07d341a2
--- /dev/null
+++ b/webhook/pkg/cli/cli.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cli provides a CLI interface for a mutating Kubernetes webhook.
+package cli
+
+import (
+ "flag"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/webhook/pkg/injector"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ k8snet "k8s.io/apimachinery/pkg/util/net"
+ "k8s.io/client-go/kubernetes"
+ "k8s.io/client-go/rest"
+)
+
+var (
+ address = flag.String("address", "", "The ip address the admission webhook serves on. If unspecified, a public address is selected automatically.")
+ port = flag.Int("port", 0, "The port the admission webhook serves on.")
+ podLabels = flag.String("pod-namespace-labels", "", "A comma-separated namespace label selector, the admission webhook will only take effect on pods in selected namespaces, e.g. `label1,label2`.")
+)
+
+// Main runs the webhook.
+func Main() {
+ flag.Parse()
+
+ if err := run(); err != nil {
+ log.Warningf("%v", err)
+ os.Exit(1)
+ }
+}
+
+func run() error {
+ log.Infof("Starting %s\n", injector.Name)
+
+ // Create client config.
+ cfg, err := rest.InClusterConfig()
+ if err != nil {
+ return fmt.Errorf("create in cluster config: %w", err)
+ }
+
+ // Create clientset.
+ clientset, err := kubernetes.NewForConfig(cfg)
+ if err != nil {
+ return fmt.Errorf("create kubernetes client: %w", err)
+ }
+
+ if err := injector.CreateConfiguration(clientset, parsePodLabels()); err != nil {
+ return fmt.Errorf("create webhook configuration: %w", err)
+ }
+
+ if err := startWebhookHTTPS(clientset); err != nil {
+ return fmt.Errorf("start webhook https server: %w", err)
+ }
+
+ return nil
+}
+
+func parsePodLabels() *metav1.LabelSelector {
+ rv := &metav1.LabelSelector{}
+ for _, s := range strings.Split(*podLabels, ",") {
+ req := metav1.LabelSelectorRequirement{
+ Key: strings.TrimSpace(s),
+ Operator: "Exists",
+ }
+ rv.MatchExpressions = append(rv.MatchExpressions, req)
+ }
+ return rv
+}
+
+func startWebhookHTTPS(clientset kubernetes.Interface) error {
+ log.Infof("Starting HTTPS handler")
+ defer log.Infof("Stopping HTTPS handler")
+
+ if *address == "" {
+ ip, err := k8snet.ChooseHostInterface()
+ if err != nil {
+ return fmt.Errorf("select ip address: %w", err)
+ }
+ *address = ip.String()
+ }
+ mux := http.NewServeMux()
+ mux.Handle("/", http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ injector.Admit(w, r)
+ }))
+ server := &http.Server{
+ // Listen on all addresses.
+ Addr: net.JoinHostPort(*address, strconv.Itoa(*port)),
+ TLSConfig: injector.GetTLSConfig(),
+ Handler: mux,
+ }
+ if err := server.ListenAndServeTLS("", ""); err != http.ErrServerClosed {
+ return fmt.Errorf("start HTTPS handler: %w", err)
+ }
+ return nil
+}
diff --git a/webhook/pkg/injector/BUILD b/webhook/pkg/injector/BUILD
new file mode 100644
index 000000000..d296981be
--- /dev/null
+++ b/webhook/pkg/injector/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "injector",
+ srcs = [
+ "certs.go",
+ "webhook.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "@com_github_mattbaird_jsonpatch//:go_default_library",
+ "@io_k8s_api//admission/v1beta1:go_default_library",
+ "@io_k8s_api//admissionregistration/v1beta1:go_default_library",
+ "@io_k8s_api//core/v1:go_default_library",
+ "@io_k8s_apimachinery//pkg/api/errors:go_default_library",
+ "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library",
+ "@io_k8s_client_go//kubernetes:go_default_library",
+ ],
+)
+
+genrule(
+ name = "certs",
+ srcs = [":gencerts"],
+ outs = ["certs.go"],
+ cmd = "$$(cut -d ' ' -f 1 <<< \"$(locations :gencerts)\") $@",
+)
+
+sh_binary(
+ name = "gencerts",
+ srcs = ["gencerts.sh"],
+)
diff --git a/webhook/pkg/injector/gencerts.sh b/webhook/pkg/injector/gencerts.sh
new file mode 100755
index 000000000..f7fda4b63
--- /dev/null
+++ b/webhook/pkg/injector/gencerts.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Generates the a CA cert, a server key, and a server cert signed by the CA.
+# reference:
+# https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/testcerts/gencerts.sh
+set -euo pipefail
+
+# Do all the work in TMPDIR, then copy out generated code and delete TMPDIR.
+declare -r OUTDIR="$(readlink -e .)"
+declare -r TMPDIR="$(mktemp -d)"
+cd "${TMPDIR}"
+function cleanup() {
+ cd "${OUTDIR}"
+ rm -rf "${TMPDIR}"
+}
+trap cleanup EXIT
+
+declare -r CN_BASE="e2e"
+declare -r CN="gvisor-injection-admission-webhook.e2e.svc"
+
+cat > server.conf << EOF
+[req]
+req_extensions = v3_req
+distinguished_name = req_distinguished_name
+[req_distinguished_name]
+[ v3_req ]
+basicConstraints = CA:FALSE
+keyUsage = nonRepudiation, digitalSignature, keyEncipherment
+extendedKeyUsage = clientAuth, serverAuth
+EOF
+
+declare -r OUTFILE="${TMPDIR}/certs.go"
+
+# We depend on OpenSSL being present.
+
+# Create a certificate authority.
+openssl genrsa -out caKey.pem 2048
+openssl req -x509 -new -nodes -key caKey.pem -days 100000 -out caCert.pem -subj "/CN=${CN_BASE}_ca" -config server.conf
+
+# Create a server certificate.
+openssl genrsa -out serverKey.pem 2048
+# Note the CN is the DNS name of the service of the webhook.
+openssl req -new -key serverKey.pem -out server.csr -subj "/CN=${CN}" -config server.conf
+openssl x509 -req -in server.csr -CA caCert.pem -CAkey caKey.pem -CAcreateserial -out serverCert.pem -days 100000 -extensions v3_req -extfile server.conf
+
+echo "package injector" > "${OUTFILE}"
+echo "" >> "${OUTFILE}"
+echo "// This file was generated using openssl by the gencerts.sh script." >> "${OUTFILE}"
+for file in caKey caCert serverKey serverCert; do
+ DATA=$(cat "${file}.pem")
+ echo "" >> "${OUTFILE}"
+ echo "var $file = []byte(\`$DATA\`)" >> "${OUTFILE}"
+done
+
+# Copy generated code into the output directory.
+cp "${OUTFILE}" "${OUTDIR}/$1"
diff --git a/webhook/pkg/injector/webhook.go b/webhook/pkg/injector/webhook.go
new file mode 100644
index 000000000..614b5add7
--- /dev/null
+++ b/webhook/pkg/injector/webhook.go
@@ -0,0 +1,211 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package injector handles mutating webhook operations.
+package injector
+
+import (
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+
+ "github.com/mattbaird/jsonpatch"
+ "gvisor.dev/gvisor/pkg/log"
+ admv1beta1 "k8s.io/api/admission/v1beta1"
+ admregv1beta1 "k8s.io/api/admissionregistration/v1beta1"
+ v1 "k8s.io/api/core/v1"
+ apierrors "k8s.io/apimachinery/pkg/api/errors"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ kubeclientset "k8s.io/client-go/kubernetes"
+)
+
+const (
+ // Name is the name of the admission webhook service. The admission
+ // webhook must be exposed in the following service; this is mainly for
+ // the server certificate.
+ Name = "gvisor-injection-admission-webhook"
+
+ // serviceNamespace is the namespace of the admission webhook service.
+ serviceNamespace = "e2e"
+
+ fullName = Name + "." + serviceNamespace + ".svc"
+)
+
+// CreateConfiguration creates MutatingWebhookConfiguration and registers the
+// webhook admission controller with the kube-apiserver. The webhook will only
+// take effect on pods in the namespaces selected by `podNsSelector`. If `podNsSelector`
+// is empty, the webhook will take effect on all pods.
+func CreateConfiguration(clientset kubeclientset.Interface, selector *metav1.LabelSelector) error {
+ fail := admregv1beta1.Fail
+
+ config := &admregv1beta1.MutatingWebhookConfiguration{
+ ObjectMeta: metav1.ObjectMeta{
+ Name: Name,
+ },
+ Webhooks: []admregv1beta1.MutatingWebhook{
+ {
+ Name: fullName,
+ ClientConfig: admregv1beta1.WebhookClientConfig{
+ Service: &admregv1beta1.ServiceReference{
+ Name: Name,
+ Namespace: serviceNamespace,
+ },
+ CABundle: caCert,
+ },
+ Rules: []admregv1beta1.RuleWithOperations{
+ {
+ Operations: []admregv1beta1.OperationType{
+ admregv1beta1.Create,
+ },
+ Rule: admregv1beta1.Rule{
+ APIGroups: []string{"*"},
+ APIVersions: []string{"*"},
+ Resources: []string{"pods"},
+ },
+ },
+ },
+ FailurePolicy: &fail,
+ NamespaceSelector: selector,
+ },
+ },
+ }
+ log.Infof("Creating MutatingWebhookConfiguration %q", config.Name)
+ if _, err := clientset.AdmissionregistrationV1beta1().MutatingWebhookConfigurations().Create(config); err != nil {
+ if !apierrors.IsAlreadyExists(err) {
+ return fmt.Errorf("failed to create MutatingWebhookConfiguration %q: %s", config.Name, err)
+ }
+ log.Infof("MutatingWebhookConfiguration %q already exists; use the existing one", config.Name)
+ }
+ return nil
+}
+
+// GetTLSConfig retrieves the CA cert that signed the cert used by the webhook.
+func GetTLSConfig() *tls.Config {
+ serverCert, err := tls.X509KeyPair(serverCert, serverKey)
+ if err != nil {
+ log.Warningf("Failed to generate X509 key pair: %v", err)
+ os.Exit(1)
+ }
+ return &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ }
+}
+
+// Admit performs admission checks and mutations on Pods.
+func Admit(writer http.ResponseWriter, req *http.Request) {
+ review := &admv1beta1.AdmissionReview{}
+ if err := json.NewDecoder(req.Body).Decode(review); err != nil {
+ log.Infof("Failed with error (%v) to decode Admit request: %+v", err, *req)
+ writer.WriteHeader(http.StatusBadRequest)
+ return
+ }
+
+ log.Debugf("admitPod: %+v", review)
+ var err error
+ review.Response, err = admitPod(review.Request)
+ if err != nil {
+ log.Warningf("admitPod failed: %v", err)
+ review.Response = &admv1beta1.AdmissionResponse{
+ Result: &metav1.Status{
+ Reason: metav1.StatusReasonInvalid,
+ Message: err.Error(),
+ },
+ }
+ sendResponse(writer, review)
+ return
+ }
+
+ log.Debugf("Processed admission review: %+v", review)
+ sendResponse(writer, review)
+}
+
+func sendResponse(writer http.ResponseWriter, response interface{}) {
+ b, err := json.Marshal(response)
+ if err != nil {
+ log.Warningf("Failed with error (%v) to marshal response: %+v", err, response)
+ writer.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+ writer.WriteHeader(http.StatusOK)
+ writer.Write(b)
+}
+
+func admitPod(req *admv1beta1.AdmissionRequest) (*admv1beta1.AdmissionResponse, error) {
+ // Verify that the request is indeed a Pod.
+ resource := metav1.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"}
+ if req.Resource != resource {
+ return nil, fmt.Errorf("unexpected resource %+v in pod admission", req.Resource)
+ }
+
+ // Decode the request into a Pod.
+ pod := &v1.Pod{}
+ if err := json.Unmarshal(req.Object.Raw, pod); err != nil {
+ return nil, fmt.Errorf("failed to decode pod object %s/%s", req.Namespace, req.Name)
+ }
+
+ // Copy first to change it.
+ podCopy := pod.DeepCopy()
+ updatePod(podCopy)
+ patch, err := createPatch(req.Object.Raw, podCopy)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create patch for pod %s/%s (generatedName: %s)", pod.Namespace, pod.Name, pod.GenerateName)
+ }
+
+ log.Debugf("Patched pod %s/%s (generateName: %s): %+v", pod.Namespace, pod.Name, pod.GenerateName, podCopy)
+ patchType := admv1beta1.PatchTypeJSONPatch
+ return &admv1beta1.AdmissionResponse{
+ Allowed: true,
+ Patch: patch,
+ PatchType: &patchType,
+ }, nil
+}
+
+func updatePod(pod *v1.Pod) {
+ gvisor := "gvisor"
+ pod.Spec.RuntimeClassName = &gvisor
+
+ // We don't run SELinux test for gvisor.
+ // If SELinuxOptions are specified, this is usually for volume test to pass
+ // on SELinux. This can be safely ignored.
+ if pod.Spec.SecurityContext != nil && pod.Spec.SecurityContext.SELinuxOptions != nil {
+ pod.Spec.SecurityContext.SELinuxOptions = nil
+ }
+ for i := range pod.Spec.Containers {
+ c := &pod.Spec.Containers[i]
+ if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil {
+ c.SecurityContext.SELinuxOptions = nil
+ }
+ }
+ for i := range pod.Spec.InitContainers {
+ c := &pod.Spec.InitContainers[i]
+ if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil {
+ c.SecurityContext.SELinuxOptions = nil
+ }
+ }
+}
+
+func createPatch(old []byte, newObj interface{}) ([]byte, error) {
+ new, err := json.Marshal(newObj)
+ if err != nil {
+ return nil, err
+ }
+ patch, err := jsonpatch.CreatePatch(old, new)
+ if err != nil {
+ return nil, err
+ }
+ return json.Marshal(patch)
+}
diff --git a/website/BUILD b/website/BUILD
index f3642b903..676c2b701 100644
--- a/website/BUILD
+++ b/website/BUILD
@@ -6,11 +6,16 @@ package(licenses = ["notice"])
docker_image(
name = "website",
- data = [":files"],
+ data = ":files",
statements = [
"EXPOSE 8080/tcp",
'ENTRYPOINT ["/server"]',
],
+ tags = [
+ "local",
+ "manual",
+ "nosandbox",
+ ],
)
# files is the full file system of the generated container.
diff --git a/website/_config.yml b/website/_config.yml
index 20fbb3d2d..dc44945bc 100644
--- a/website/_config.yml
+++ b/website/_config.yml
@@ -37,3 +37,10 @@ authors:
fvoznika:
name: Fabricio Voznika
email: fvoznika@google.com
+ ianlewis:
+ name: Ian Lewis
+ email: ianlewis@google.com
+ url: https://twitter.com/IanMLewis
+ mpratt:
+ name: Michael Pratt
+ email: mpratt@google.com
diff --git a/website/_includes/byline.html b/website/_includes/byline.html
index d8ae22cb0..1e808260f 100644
--- a/website/_includes/byline.html
+++ b/website/_includes/byline.html
@@ -5,7 +5,7 @@ By
{% assign author_id=include.authors[i] %}
{% assign author=site.authors[author_id] %}
{% if author %}
- <a href="mailto:{{ author.email }}">{{ author.name }}</a>
+ <a href="{% if author.url %}{{ author.url }}{% else %}mailto:{{ author.email }}{% endif %}">{{ author.name }}</a>
{% else %}
{{ author_id }}
{% endif %}
diff --git a/website/blog/2020-10-22-platform-portability.md b/website/blog/2020-10-22-platform-portability.md
new file mode 100644
index 000000000..4d82940f9
--- /dev/null
+++ b/website/blog/2020-10-22-platform-portability.md
@@ -0,0 +1,120 @@
+# Platform Portability
+
+Hardware virtualization is often seen as a requirement to provide an additional
+isolation layer for untrusted applications. However, hardware virtualization
+requires expensive bare-metal machines or cloud instances to run safely with
+good performance, increasing cost and complexity for Cloud users. gVisor,
+however, takes a more flexible approach.
+
+One of the pillars of gVisor's architecture is portability, allowing it to run
+anywhere that runs Linux. Modern Cloud-Native applications run in containers in
+many different places, from bare metal to virtual machines, and can't always
+rely on nested virtualization. It is important for gVisor to be able to support
+the environments where you run containers.
+
+gVisor achieves portability through an abstraction called a _Platform_.
+Platforms can have many implementations, and each implementation can cover
+different environments, making use of available software or hardware features.
+
+## Background
+
+Before we can understand how gVisor achieves portability using platforms, we
+should take a step back and understand how applications interact with their
+host.
+
+Container sandboxes can provide an isolation layer between the host and
+application by virtualizing one of the layers below it, including the hardware
+or operating system. Many sandboxes virtualize the hardware layer by running
+applications in virtual machines. gVisor takes a different approach by
+virtualizing the OS layer.
+
+When an application is run in a normal situation the host operating system loads
+the application into user memory and schedules it for execution. The operating
+system scheduler eventually schedules the application to a CPU and begins
+executing it. It then handles the application's requests, such as for memory and
+the lifecycle of the application. gVisor virtualizes these interactions, such as
+system calls, and context switching that happen between an application and OS.
+
+[System calls](https://en.wikipedia.org/wiki/System_call) allow applications to
+ask the OS to perform some task for it. System calls look like a normal function
+call in most programming languages though works a bit differently under the
+hood. When an application system call is encountered some special processing
+takes place to do a
+[context switch](https://en.wikipedia.org/wiki/Context_switch) into kernel mode
+and begin executing code in the kernel before returning a result to the
+application. Context switching may happen in other situations as well. For
+example, to respond to an interrupt.
+
+## The Platform Interface
+
+gVisor provides a sandbox which implements the Linux OS interface, intercepting
+OS interactions such as system calls and implements them in the sandbox kernel.
+
+It does this to limit interactions with the host, and protect the host from an
+untrusted application running in the sandbox. The Platform is the bottom layer
+of gVisor which provides the environment necessary for gVisor to control and
+manage applications. In general, the Platform must:
+
+1. Provide the ability to create and manage memory address spaces.
+2. Provide execution contexts for running applications in those memory address
+ spaces.
+3. Provide the ability to change execution context and return control to gVisor
+ at specific times (e.g. system call, page fault)
+
+This interface is conceptually simple, but very powerful. Since the Platform
+interface only requires these three capabilities, it gives gVisor enough control
+for it to act as the application's OS, while still allowing the use of very
+different isolation technologies under the hood. You can learn more about the
+Platform interface in the
+[Platform Guide](https://gvisor.dev/docs/architecture_guide/platforms/).
+
+## Implementations of the Platform Interface
+
+While gVisor can make use of technologies like hardware virtualization, it
+doesn't necessarily rely on any one technology to provide a similar level of
+isolation. The flexibility of the Platform interface allows for implementations
+that use technologies other than hardware virtualization. This allows gVisor to
+run in VMs without nested virtualization, for example. By providing an
+abstraction for the underlying platform, each implementation can make various
+tradeoffs regarding performance or hardware requirements.
+
+Currently gVisor provides two gVisor Platform implementations; the Ptrace
+Platform, and the KVM Platform, each using very different methods to implement
+the Platform interface.
+
+![gVisor Platforms](../../../../../docs/architecture_guide/platforms/platforms.png "Platforms")
+
+The Ptrace Platform uses
+[PTRACE\_SYSEMU](http://man7.org/linux/man-pages/man2/ptrace.2.html) to trap
+syscalls, and uses the host for memory mapping and context switching. This
+platform can run anywhere that ptrace is available, which includes most Linux
+systems, VMs or otherwise.
+
+The KVM Platform uses virtualization, but in an unconventional way. gVisor runs
+in a virtual machine but as both guest OS and VMM, and presents no virtualized
+hardware layer. This provides a simpler interface that can avoid hardware
+initialization for fast start up, while taking advantage of hardware
+virtualization support to improve memory isolation and performance of context
+switching.
+
+The flexibility of the Platform interface allows for a lot of room to improve
+the existing KVM and ptrace platforms, as well as the ability to utilize new
+methods for improving gVisor's performance or portability in future Platform
+implementations.
+
+## Portability
+
+Through the Platform interface, gVisor is able to support bare metal, virtual
+machines, and Cloud environments while still providing a highly secure sandbox
+for running untrusted applications. This is especially important for Cloud and
+Kubernetes users because it allows gVisor to run anywhere that Kubernetes can
+run and provide similar experiences in multi-region, hybrid, multi-platform
+environments.
+
+Give gVisor's open source platforms a try. Using a Platform is as easy as
+providing the `--platform` flag to `runsc`. See the documentation on
+[changing platforms](https://gvisor.dev/docs/user_guide/platforms/) for how to
+use different platforms with Docker. We would love to hear about your experience
+so come chat with us in our
+[Gitter channel](https://gitter.im/gvisor/community), or send us an
+[issue on Github](https://gvisor.dev/issue) if you run into any problems.
diff --git a/website/blog/BUILD b/website/blog/BUILD
index 865e403da..17beb721f 100644
--- a/website/blog/BUILD
+++ b/website/blog/BUILD
@@ -38,6 +38,17 @@ doc(
permalink = "/blog/2020/09/18/containing-a-real-vulnerability/",
)
+doc(
+ name = "platform_portability",
+ src = "2020-10-22-platform-portability.md",
+ authors = [
+ "ianlewis",
+ "mpratt",
+ ],
+ layout = "post",
+ permalink = "/blog/2020/10/22/platform-portability/",
+)
+
docs(
name = "posts",
deps = [
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
index c401b6abd..ac09550a9 100644
--- a/website/cmd/server/main.go
+++ b/website/cmd/server/main.go
@@ -29,6 +29,7 @@ var redirects = map[string]string{
// GitHub redirects.
"/change": "https://github.com/google/gvisor",
"/issue": "https://github.com/google/gvisor/issues",
+ "/issues": "https://github.com/google/gvisor/issues",
"/issue/new": "https://github.com/google/gvisor/issues/new",
"/pr": "https://github.com/google/gvisor/pulls",
@@ -44,14 +45,16 @@ var redirects = map[string]string{
"/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/",
// Redirect for old URLs.
- "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
- "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
- "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
- "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
- "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
- "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
- "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
- "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+ "/blog/2020/09/22/platform-portability": "/blog/2020/10/22/platform-portability/",
+ "/blog/2020/09/22/platform-portability/": "/blog/2020/10/22/platform-portability/",
// Deprecated, but links continue to work.
"/cl": "https://gvisor-review.googlesource.com",
@@ -60,6 +63,7 @@ var redirects = map[string]string{
var prefixHelpers = map[string]string{
"change": "https://github.com/google/gvisor/commit/%s",
"issue": "https://github.com/google/gvisor/issues/%s",
+ "issues": "https://github.com/google/gvisor/issues/%s",
"pr": "https://github.com/google/gvisor/pull/%s",
// Redirects to compatibility docs.