summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc36
-rw-r--r--.gitignore2
-rw-r--r--BUILD24
-rw-r--r--CONTRIBUTING.md6
-rw-r--r--Makefile14
-rw-r--r--WORKSPACE133
-rw-r--r--g3doc/user_guide/containerd/quick_start.md4
-rw-r--r--g3doc/user_guide/install.md19
-rw-r--r--go.mod7
-rw-r--r--go.sum44
-rw-r--r--images/defs.bzl17
-rw-r--r--nogo.yaml495
-rw-r--r--pkg/abi/linux/BUILD2
-rw-r--r--pkg/abi/linux/ioctl.go3
-rw-r--r--pkg/abi/linux/sem.go12
-rw-r--r--pkg/abi/linux/sem_amd64.go33
-rw-r--r--pkg/abi/linux/sem_arm64.go31
-rw-r--r--pkg/context/context.go24
-rw-r--r--pkg/eventchannel/BUILD8
-rw-r--r--pkg/eventchannel/event.go40
-rw-r--r--pkg/eventchannel/event.proto2
-rw-r--r--pkg/eventchannel/event_any.go25
-rw-r--r--pkg/eventchannel/event_test.go2
-rw-r--r--pkg/eventchannel/rate.go2
-rw-r--r--pkg/merkletree/BUILD10
-rw-r--r--pkg/merkletree/merkletree.go104
-rw-r--r--pkg/merkletree/merkletree_test.go394
-rw-r--r--pkg/metric/BUILD2
-rw-r--r--pkg/metric/metric_test.go2
-rw-r--r--pkg/refsvfs2/BUILD (renamed from pkg/refs_vfs2/BUILD)16
-rw-r--r--pkg/refsvfs2/refs.go (renamed from pkg/refs_vfs2/refs.go)4
-rw-r--r--pkg/refsvfs2/refs_map.go97
-rw-r--r--pkg/refsvfs2/refs_template.go (renamed from pkg/refs_vfs2/refs_template.go)42
-rw-r--r--pkg/sentry/control/state.go2
-rw-r--r--pkg/sentry/devices/tundev/tundev.go4
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go7
-rw-r--r--pkg/sentry/fs/proc/task.go83
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD3
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go15
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go4
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go2
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/BUILD5
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/save_restore.go23
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD3
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go18
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go6
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD3
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go10
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go29
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go440
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go20
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go22
-rw-r--r--pkg/sentry/fsimpl/gofer/save_restore.go329
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go7
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go108
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go12
-rw-r--r--pkg/sentry/fsimpl/host/BUILD6
-rw-r--r--pkg/sentry/fsimpl/host/control.go2
-rw-r--r--pkg/sentry/fsimpl/host/host.go249
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go2
-rw-r--r--pkg/sentry/fsimpl/host/save_restore.go78
-rw-r--r--pkg/sentry/fsimpl/host/util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD37
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go14
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go96
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go205
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go62
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go8
-rw-r--r--pkg/sentry/fsimpl/kernfs/synthetic_directory.go11
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD7
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go70
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go60
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go198
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go (renamed from pkg/sentry/fsimpl/overlay/non_directory.go)115
-rw-r--r--pkg/sentry/fsimpl/overlay/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD11
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go34
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task.go7
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go16
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go173
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go34
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go30
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go8
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go116
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go1
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go4
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD3
-rw-r--r--pkg/sentry/fsimpl/sys/kcov.go4
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go68
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go7
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/tmpfs/save_restore.go20
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go11
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD6
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go138
-rw-r--r--pkg/sentry/fsimpl/verity/save_restore.go27
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go122
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go263
-rw-r--r--pkg/sentry/inet/inet.go6
-rw-r--r--pkg/sentry/inet/test_stack.go21
-rw-r--r--pkg/sentry/kernel/BUILD13
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go10
-rw-r--r--pkg/sentry/kernel/fd_table.go6
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go13
-rw-r--r--pkg/sentry/kernel/fs_context.go24
-rw-r--r--pkg/sentry/kernel/ipc_namespace.go2
-rw-r--r--pkg/sentry/kernel/kcov.go19
-rw-r--r--pkg/sentry/kernel/kcov_unsafe.go6
-rw-r--r--pkg/sentry/kernel/kernel.go188
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go7
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go29
-rw-r--r--pkg/sentry/kernel/shm/BUILD4
-rw-r--r--pkg/sentry/kernel/task_clone.go21
-rw-r--r--pkg/sentry/kernel/task_exit.go2
-rw-r--r--pkg/sentry/kernel/task_start.go14
-rw-r--r--pkg/sentry/kernel/thread_group.go7
-rw-r--r--pkg/sentry/loader/elf.go8
-rw-r--r--pkg/sentry/mm/BUILD5
-rw-r--r--pkg/sentry/pgalloc/BUILD1
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go122
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go13
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s12
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go33
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go32
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go29
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s34
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go6
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go2
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go10
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go3
-rw-r--r--pkg/sentry/socket/hostinet/stack.go21
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go62
-rw-r--r--pkg/sentry/socket/netlink/socket.go7
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go17
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go4
-rw-r--r--pkg/sentry/socket/netstack/stack.go85
-rw-r--r--pkg/sentry/socket/unix/BUILD5
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go16
-rw-r--r--pkg/sentry/socket/unix/unix.go1
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go3
-rw-r--r--pkg/sentry/state/BUILD2
-rw-r--r--pkg/sentry/state/state.go10
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go20
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go20
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go6
-rw-r--r--pkg/sentry/usage/memory.go2
-rw-r--r--pkg/sentry/vfs/BUILD8
-rw-r--r--pkg/sentry/vfs/epoll.go2
-rw-r--r--pkg/sentry/vfs/file_description.go1
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go11
-rw-r--r--pkg/sentry/vfs/lock.go5
-rw-r--r--pkg/sentry/vfs/mount.go48
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go21
-rw-r--r--pkg/sentry/vfs/save_restore.go124
-rw-r--r--pkg/sentry/vfs/vfs.go24
-rw-r--r--pkg/shim/v2/runtimeoptions/BUILD12
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions.go15
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go383
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions_test.go9
-rw-r--r--pkg/state/BUILD14
-rw-r--r--pkg/state/decode.go80
-rw-r--r--pkg/state/decode_unsafe.go57
-rw-r--r--pkg/state/encode.go221
-rw-r--r--pkg/state/pretty/pretty.go41
-rw-r--r--pkg/state/state.go10
-rw-r--r--pkg/state/tests/struct.go35
-rw-r--r--pkg/state/tests/struct_test.go34
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go11
-rw-r--r--pkg/tcpip/checker/checker.go67
-rw-r--r--pkg/tcpip/header/icmpv4.go9
-rw-r--r--pkg/tcpip/header/ipv4.go493
-rw-r--r--pkg/tcpip/header/ipv6.go6
-rw-r--r--pkg/tcpip/link/ethernet/BUILD15
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go99
-rw-r--r--pkg/tcpip/link/pipe/pipe.go39
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go36
-rw-r--r--pkg/tcpip/link/tun/BUILD3
-rw-r--r--pkg/tcpip/link/tun/device.go2
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go84
-rw-r--r--pkg/tcpip/network/arp/arp_test.go201
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go58
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go120
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go22
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go23
-rw-r--r--pkg/tcpip/network/ip_test.go422
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go124
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go581
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go1084
-rw-r--r--pkg/tcpip/network/ipv6/BUILD1
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go80
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go128
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go226
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go505
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go2
-rw-r--r--pkg/tcpip/stack/forwarding_test.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go8
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go24
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go516
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go193
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go1235
-rw-r--r--pkg/tcpip/stack/nic.go17
-rw-r--r--pkg/tcpip/stack/nic_test.go2
-rw-r--r--pkg/tcpip/stack/nud.go8
-rw-r--r--pkg/tcpip/stack/packet_buffer.go20
-rw-r--r--pkg/tcpip/stack/registration.go15
-rw-r--r--pkg/tcpip/stack/stack.go16
-rw-r--r--pkg/tcpip/stack/stack_test.go84
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go26
-rw-r--r--pkg/tcpip/tests/integration/BUILD1
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go13
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go7
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go7
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint_state.go38
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/sack_scoreboard.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go20
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go3
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go19
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go5
-rw-r--r--pkg/waiter/waiter.go2
-rw-r--r--runsc/BUILD22
-rw-r--r--runsc/boot/BUILD3
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/controller.go19
-rw-r--r--runsc/boot/fs.go33
-rw-r--r--runsc/boot/loader.go9
-rw-r--r--runsc/boot/loader_test.go57
-rw-r--r--runsc/boot/vfs.go60
-rw-r--r--runsc/cgroup/cgroup.go53
-rw-r--r--runsc/cgroup/cgroup_test.go80
-rw-r--r--runsc/cli/BUILD22
-rw-r--r--runsc/cli/main.go256
-rw-r--r--runsc/cmd/boot.go6
-rw-r--r--runsc/cmd/do.go45
-rw-r--r--runsc/cmd/start.go7
-rw-r--r--runsc/config/config.go3
-rw-r--r--runsc/config/flags.go1
-rw-r--r--runsc/container/container.go18
-rw-r--r--runsc/container/container_test.go6
-rw-r--r--runsc/main.go240
-rw-r--r--runsc/specutils/specutils.go18
-rw-r--r--shim/v1/BUILD21
-rw-r--r--shim/v1/cli/BUILD30
-rw-r--r--shim/v1/cli/api.go (renamed from shim/v1/api.go)2
-rw-r--r--shim/v1/cli/cli.go267
-rw-r--r--shim/v1/cli/config.go (renamed from shim/v1/config.go)2
-rw-r--r--shim/v1/main.go249
-rw-r--r--shim/v2/BUILD9
-rw-r--r--shim/v2/cli/BUILD16
-rw-r--r--shim/v2/cli/cli.go28
-rw-r--r--shim/v2/main.go10
-rw-r--r--test/iptables/filter_output.go17
-rw-r--r--test/iptables/nat.go15
-rw-r--r--test/runner/defs.bzl12
-rw-r--r--test/runner/runner.go6
-rw-r--r--test/runtimes/exclude/java11.csv1
-rw-r--r--test/runtimes/exclude/nodejs12.4.0.csv41
-rw-r--r--test/runtimes/exclude/php7.3.6.csv3
-rw-r--r--test/runtimes/exclude/python3.7.3.csv1
-rw-r--r--test/runtimes/proctor/main.go28
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/flock.cc46
-rw-r--r--test/syscalls/linux/mknod.cc46
-rw-r--r--test/syscalls/linux/mount.cc38
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc11
-rw-r--r--test/syscalls/linux/pipe.cc30
-rw-r--r--test/syscalls/linux/proc.cc409
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc2
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc13
-rw-r--r--test/syscalls/linux/semaphore.cc56
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc16
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc51
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc14
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc82
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc28
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc107
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc22
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h11
-rw-r--r--test/syscalls/linux/socket_netlink_util.cc11
-rw-r--r--test/syscalls/linux/socket_netlink_util.h8
-rw-r--r--test/syscalls/linux/socket_test_util.cc16
-rw-r--r--test/syscalls/linux/socket_test_util.h9
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc13
-rw-r--r--test/syscalls/linux/stat.cc34
-rw-r--r--test/syscalls/linux/tcp_socket.cc57
-rw-r--r--test/syscalls/linux/timers.cc94
-rw-r--r--test/syscalls/linux/udp_socket.cc56
-rw-r--r--test/util/BUILD4
-rw-r--r--test/util/posix_error.cc2
-rw-r--r--test/util/posix_error.h34
-rw-r--r--test/util/save_util.cc52
-rw-r--r--test/util/save_util.h20
-rw-r--r--test/util/save_util_linux.cc24
-rw-r--r--test/util/save_util_other.cc8
-rw-r--r--test/util/timer_util.cc18
-rw-r--r--test/util/timer_util.h94
-rw-r--r--tools/bazel.mk10
-rw-r--r--tools/bazeldefs/BUILD41
-rw-r--r--tools/bazeldefs/cc.bzl43
-rw-r--r--tools/bazeldefs/defs.bzl158
-rw-r--r--tools/bazeldefs/go.bzl142
-rw-r--r--tools/bazeldefs/pkg.bzl6
-rw-r--r--tools/bigquery/BUILD5
-rw-r--r--tools/bigquery/bigquery.go60
-rw-r--r--tools/checkescape/checkescape.go11
-rw-r--r--tools/defs.bzl60
-rw-r--r--tools/github/main.go36
-rw-r--r--tools/github/nogo/BUILD2
-rw-r--r--tools/github/nogo/nogo.go56
-rw-r--r--tools/github/reviver/github.go19
-rw-r--r--tools/github/reviver/reviver_test.go9
-rw-r--r--tools/go_generics/defs.bzl108
-rw-r--r--tools/nogo/BUILD21
-rw-r--r--tools/nogo/analyzers.go131
-rw-r--r--tools/nogo/build.go16
-rw-r--r--tools/nogo/check/BUILD2
-rw-r--r--tools/nogo/check/main.go92
-rw-r--r--tools/nogo/config.go756
-rw-r--r--tools/nogo/defs.bzl265
-rw-r--r--tools/nogo/filter/BUILD14
-rw-r--r--tools/nogo/filter/main.go131
-rw-r--r--tools/nogo/findings.go63
-rwxr-xr-xtools/nogo/gentest.sh47
-rw-r--r--tools/nogo/io_bazel_rules_go-visibility.patch25
-rw-r--r--tools/nogo/matchers.go172
-rw-r--r--tools/nogo/nogo.go191
-rw-r--r--tools/nogo/register.go67
-rw-r--r--tools/nogo/util/BUILD9
-rw-r--r--tools/nogo/util/util.go85
-rw-r--r--tools/parsers/BUILD16
-rw-r--r--tools/parsers/go_parser.go17
-rw-r--r--tools/parsers/go_parser_test.go12
-rw-r--r--tools/parsers/parser_main.go129
-rw-r--r--tools/rules_go.patch14
-rwxr-xr-xtools/tag_release.sh6
-rw-r--r--website/BUILD7
-rw-r--r--website/_config.yml6
-rw-r--r--website/blog/2020-10-22-platform-portability.md120
-rw-r--r--website/blog/BUILD11
-rw-r--r--website/cmd/server/main.go20
364 files changed, 14538 insertions, 6047 deletions
diff --git a/.bazelrc b/.bazelrc
index a2fe95822..e2848ef07 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -28,20 +28,12 @@ build:remote --bes_results_url="https://source.cloud.google.com/results/invocati
build:remote --bes_timeout=600s
build:remote --project_id=gvisor-rbe
build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance
-build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com
-build:remote3 --project_id=gvisor-rbe
-build:remote3 --bes_backend=buildeventservice.googleapis.com
-build:remote3 --bes_results_url="https://source.cloud.google.com/results/invocations"
-build:remote3 --bes_timeout=600s
-build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance
# Enable authentication. This will pick up application default credentials by
# default. You can use --google_credentials=some_file.json to use a service
# account credential instead.
build:remote --google_default_credentials=true
build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
-build:remote3 --google_default_credentials=true
-build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
# Add a custom platform and toolchain that builds in a privileged docker
# container, which is required by our syscall tests.
@@ -50,31 +42,5 @@ build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-defa
build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604
build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
build:remote --crosstool_top=@rbe_default//cc:toolchain
-build:remote --jobs=100
+build:remote --jobs=300
build:remote --remote_timeout=3600
-build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3
-build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3
-build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
-build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
-build:remote3 --crosstool_top=@rbe_default//cc:toolchain
-build:remote3 --jobs=100
-build:remote3 --remote_timeout=3600
-
-# Set flags for uploading to BES in order to view results in the Bazel Build
-# Results UI.
-build:results --bes_backend="buildeventservice.googleapis.com"
-build:results --bes_timeout=60s
-build:results --tls_enabled
-
-# Output BES results url
-build:results --bes_results_url="https://source.cloud.google.com/results/invocations/"
-
-# Set flags for uploading to BES without Remote Build Execution.
-build:results-local --bes_backend="buildeventservice.googleapis.com"
-build:results-local --bes_timeout=60s
-build:results-local --tls_enabled=true
-build:results-local --auth_enabled=true
-build:results-local --spawn_strategy=local
-build:results-local --remote_cache=remotebuildexecution.googleapis.com
-build:results-local --remote_timeout=3600
-build:results-local --bes_results_url="https://source.cloud.google.com/results/invocations/"
diff --git a/.gitignore b/.gitignore
index 13babef4d..a56f6ebcd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,2 @@
# Generated bazel symlinks.
-/bazel-*
+/bazel-* \ No newline at end of file
diff --git a/BUILD b/BUILD
index 2639f8169..153464220 100644
--- a/BUILD
+++ b/BUILD
@@ -1,10 +1,17 @@
load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
+load("//tools/nogo:defs.bzl", "nogo_config")
load("//website:defs.bzl", "doc")
package(licenses = ["notice"])
exports_files(["LICENSE"])
+nogo_config(
+ name = "nogo_config",
+ srcs = ["nogo.yaml"],
+ visibility = ["//:sandbox"],
+)
+
doc(
name = "contributing",
src = "CONTRIBUTING.md",
@@ -75,12 +82,19 @@ go_path(
name = "gopath",
mode = "link",
deps = [
- # Main binary.
- "//runsc",
- "//shim/v1:gvisor-containerd-shim",
- "//shim/v2:containerd-shim-runsc-v1",
+ # Main binaries.
+ #
+ # For reasons related to reproducibility of the generated
+ # files, in order to ensure that :gopath produces only a
+ # a single "pure" version of all files, we can only depend
+ # on go_library targets here, and not go_binary. Thus the
+ # binaries have been factored into a cli package, which is
+ # a good practice in any case.
+ "//runsc/cli",
+ "//shim/v1/cli",
+ "//shim/v2/cli",
- # Packages that are not dependencies of //runsc.
+ # Packages that are not dependencies of the above.
"//pkg/sentry/kernel/memevent",
"//pkg/tcpip/adapters/gonet",
"//pkg/tcpip/link/channel",
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 89180eb3f..c53df7d25 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -70,10 +70,8 @@ Rules:
* `@org_golang_x_sys//unix:go_default_library` (Go import
`golang.org/x/sys/unix`).
* Generated Go protobuf packages.
- * `@com_github_golang_protobuf//proto:go_default_library` (Go import
- `github.com/golang/protobuf/proto`).
- * `@com_github_golang_protobuf//ptypes:go_default_library` (Go import
- `github.com/golang/protobuf/ptypes`).
+ * `@org_golang_google_protobuf//proto:go_default_library` (Go import
+ `google.golang.org/protobuf`).
* `runsc` may only depend on the following packages:
diff --git a/Makefile b/Makefile
index cdfcc63b3..88f23de8d 100644
--- a/Makefile
+++ b/Makefile
@@ -130,8 +130,7 @@ unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc.
.PHONY: unit-tests
tests: ## Runs all unit tests and syscall tests.
-tests: unit-tests
- @$(call submake,test TARGETS="test/syscalls/...")
+tests: unit-tests syscall-tests
.PHONY: tests
integration-tests: ## Run all standard integration tests.
@@ -147,15 +146,14 @@ network-tests: iptables-tests packetdrill-tests packetimpact-tests
INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
syscall-%-tests:
- @$(call submake,test OPTIONS="--test_tag_filters runsc_$* test/syscalls/...")
+ @$(call submake,test OPTIONS="--test_tag_filters runsc_$*" TARGETS="test/syscalls/...")
syscall-native-tests:
- @$(call submake,test OPTIONS="--test_tag_filters native test/syscalls/...")
+ @$(call submake,test OPTIONS="--test_tag_filters native" TARGETS="test/syscalls/...")
.PHONY: syscall-native-tests
syscall-tests: ## Run all system call tests.
-syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests
-.PHONY: syscall-tests
+ @$(call submake,test TARGETS="test/syscalls/...")
%-runtime-tests: load-runtimes_%
@$(call submake,install-test-runtime)
@@ -262,7 +260,7 @@ website-build: load-jekyll ## Build the site image locally.
.PHONY: website-build
website-server: website-build ## Run a local server for development.
- @docker run -i -p 8080:8080 gvisor.dev/images/website
+ @docker run -i -p 8080:8080 $(WEBSITE_IMAGE)
.PHONY: website-server
website-push: website-build ## Push a new image and update the service.
@@ -382,7 +380,7 @@ test-runtime: ## A convenient wrapper around test that provides the runtime argu
nogo: ## Surfaces all nogo findings.
@$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...")
- @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo")
+ @$(call submake,run TARGETS="//tools/github" ARGS="$(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo")
.PHONY: nogo
gazelle: ## Runs gazelle to update WORKSPACE.
diff --git a/WORKSPACE b/WORKSPACE
index ef179e617..30d21e472 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -20,47 +20,28 @@ bazel_skylib_workspace()
# Note that this repository actually patches some other Go repositories as it
# loads it, in order to limit visibility. We hack this process by patching the
# patch used by the Go rules, turning the trick against itself.
+
http_archive(
name = "io_bazel_rules_go",
+ sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00",
patch_args = ["-p1"],
patches = [
- "//tools/nogo:io_bazel_rules_go-visibility.patch",
+ # Newer versions of the rules_go rules will automatically strip test
+ # binaries of symbols, which we don't want.
+ "//tools:rules_go.patch",
],
- sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10",
- urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
- ],
-)
-
-http_archive(
- name = "io_bazel_rules_go_bazel3", # To replace the above.
- patch_args = ["-p1"],
- patches = [
- "//tools/nogo:io_bazel_rules_go-visibility.patch",
- ],
- sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171",
- urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
- ],
-)
-
-http_archive(
- name = "bazel_gazelle_bazel3", # To replace the above.
- sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8",
+ sha256 = "b85f48fa105c4403326e9525ad2b2cc437babaa6e15a3fc0b1dbab0ab064bc7c",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz",
],
)
@@ -117,16 +98,6 @@ rules_proto_toolchains()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e",
- strip_prefix = "bazel-toolchains-3.0.1",
- urls = [
- "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
- ],
-)
-
-http_archive(
- name = "bazel_toolchains_bazel3", # To replace the above.
sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48",
strip_prefix = "bazel-toolchains-3.1.1",
urls = [
@@ -208,8 +179,8 @@ http_archive(
go_repository(
name = "com_github_sirupsen_logrus",
importpath = "github.com/sirupsen/logrus",
- sum = "h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=",
- version = "v1.4.2",
+ sum = "h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=",
+ version = "v1.6.0",
)
go_repository(
@@ -357,8 +328,8 @@ go_repository(
go_repository(
name = "org_golang_x_tools",
importpath = "golang.org/x/tools",
- sum = "h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ=",
- version = "v0.0.0-20200918232735-d647fc253266",
+ sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=",
+ version = "v0.0.0-20201002184944-ecd9fd270d5d",
)
go_repository(
@@ -441,8 +412,8 @@ go_repository(
go_repository(
name = "com_github_konsorten_go_windows_terminal_sequences",
importpath = "github.com/konsorten/go-windows-terminal-sequences",
- sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=",
- version = "v1.0.2",
+ sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=",
+ version = "v1.0.3",
)
go_repository(
@@ -561,8 +532,8 @@ go_repository(
go_repository(
name = "com_github_containerd_continuity",
importpath = "github.com/containerd/continuity",
- sum = "h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=",
- version = "v0.0.0-20200710164510-efbc4488d8fe",
+ sum = "h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=",
+ version = "v0.0.0-20200928162600-f2cc35102c2a",
)
go_repository(
@@ -603,8 +574,8 @@ go_repository(
go_repository(
name = "com_github_dustin_go_humanize",
importpath = "github.com/dustin/go-humanize",
- sum = "h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=",
- version = "v0.0.0-20171111073723-bb3d318650d4",
+ sum = "h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=",
+ version = "v1.0.0",
)
go_repository(
@@ -622,13 +593,6 @@ go_repository(
)
go_repository(
- name = "com_github_fsnotify_fsnotify",
- importpath = "github.com/fsnotify/fsnotify",
- sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=",
- version = "v1.4.7",
-)
-
-go_repository(
name = "com_github_godbus_dbus",
importpath = "github.com/godbus/dbus",
sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=",
@@ -685,13 +649,6 @@ go_repository(
)
go_repository(
- name = "com_github_hpcloud_tail",
- importpath = "github.com/hpcloud/tail",
- sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=",
- version = "v1.0.0",
-)
-
-go_repository(
name = "com_github_inconshreveable_mousetrap",
importpath = "github.com/inconshreveable/mousetrap",
sum = "h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=",
@@ -720,20 +677,6 @@ go_repository(
)
go_repository(
- name = "com_github_onsi_ginkgo",
- importpath = "github.com/onsi/ginkgo",
- sum = "h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=",
- version = "v1.10.1",
-)
-
-go_repository(
- name = "com_github_onsi_gomega",
- importpath = "github.com/onsi/gomega",
- sum = "h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=",
- version = "v1.7.0",
-)
-
-go_repository(
name = "com_github_opencontainers_runc",
importpath = "github.com/opencontainers/runc",
sum = "h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=",
@@ -797,34 +740,6 @@ go_repository(
)
go_repository(
- name = "in_gopkg_airbrake_gobrake_v2",
- importpath = "gopkg.in/airbrake/gobrake.v2",
- sum = "h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=",
- version = "v2.0.9",
-)
-
-go_repository(
- name = "in_gopkg_fsnotify_v1",
- importpath = "gopkg.in/fsnotify.v1",
- sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=",
- version = "v1.4.7",
-)
-
-go_repository(
- name = "in_gopkg_gemnasium_logrus_airbrake_hook_v2",
- importpath = "gopkg.in/gemnasium/logrus-airbrake-hook.v2",
- sum = "h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0=",
- version = "v2.1.2",
-)
-
-go_repository(
- name = "in_gopkg_tomb_v1",
- importpath = "gopkg.in/tomb.v1",
- sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=",
- version = "v1.0.0-20141024135613-dd632973f1e7",
-)
-
-go_repository(
name = "in_gopkg_yaml_v2",
importpath = "gopkg.in/yaml.v2",
sum = "h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=",
@@ -848,15 +763,15 @@ go_repository(
go_repository(
name = "org_golang_google_genproto",
importpath = "google.golang.org/genproto",
- sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=",
- version = "v0.0.0-20200117163144-32f20d992d24",
+ sum = "h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=",
+ version = "v0.0.0-20200526211855-cb27e3aa2013",
)
go_repository(
name = "org_golang_google_protobuf",
importpath = "google.golang.org/protobuf",
- sum = "h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=",
- version = "v1.23.0",
+ sum = "h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=",
+ version = "v1.25.1-0.20200808011614-a180de9f97d9",
)
go_repository(
diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md
index b6a3186d8..a98fe5c4a 100644
--- a/g3doc/user_guide/containerd/quick_start.md
+++ b/g3doc/user_guide/containerd/quick_start.md
@@ -1,7 +1,7 @@
# Containerd Quick Start
-This document describes how to install and configure `containerd-shim-runsc-v1`
-using the containerd runtime handler support on `containerd` 1.2 or later.
+This document describes how to use `containerd-shim-runsc-v1` with the
+containerd runtime handler support on `containerd` 1.2 or later.
> ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you
> may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details.
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
index abb9e8582..c3ced9d61 100644
--- a/g3doc/user_guide/install.md
+++ b/g3doc/user_guide/install.md
@@ -13,15 +13,19 @@ To download and install the latest release manually follow these steps:
(
set -e
URL=https://storage.googleapis.com/gvisor/releases/release/latest
- wget ${URL}/runsc ${URL}/runsc.sha512
- sha512sum -c runsc.sha512
- rm -f runsc.sha512
- sudo mv runsc /usr/local/bin
- sudo chmod a+rx /usr/local/bin/runsc
+ wget ${URL}/runsc ${URL}/runsc.sha512 \
+ ${URL}/gvisor-containerd-shim ${URL}/gvisor-containerd-shim.sha512 \
+ ${URL}/containerd-shim-runsc-v1 ${URL}/containerd-shim-runsc-v1.sha512
+ sha512sum -c runsc.sha512 \
+ -c gvisor-containerd-shim.sha512 \
+ -c containerd-shim-runsc-v1.sha512
+ rm -f *.sha512
+ chmod a+rx runsc gvisor-containerd-shim containerd-shim-runsc-v1
+ sudo mv runsc gvisor-containerd-shim containerd-shim-runsc-v1 /usr/local/bin
)
```
-To install gVisor with Docker, run the following commands:
+To install gVisor as a Docker runtime, run the following commands:
```bash
/usr/local/bin/runsc install
@@ -165,5 +169,6 @@ You can use this link with the steps described in
Note that `apt` installation of a specific point release is not supported.
After installation, try out `runsc` by following the
-[Docker Quick Start](./quick_start/docker.md) or
+[Docker Quick Start](./quick_start/docker.md),
+[Containerd QuickStart](./containerd/quick_start.md), or
[OCI Quick Start](./quick_start/oci.md).
diff --git a/go.mod b/go.mod
index 7cc40f9ab..e6df99177 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect
github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect
github.com/containerd/containerd v1.3.4 // indirect
- github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect
+ github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect
github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect
github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect
@@ -29,14 +29,12 @@ require (
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect
github.com/gogo/googleapis v1.4.0 // indirect
- github.com/golang/protobuf v1.4.2 // indirect
github.com/google/go-cmp v0.5.1 // indirect
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect
github.com/hashicorp/go-multierror v1.0.0 // indirect
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect
- github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.1 // indirect
github.com/opencontainers/runc v0.1.1 // indirect
github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect
@@ -48,8 +46,9 @@ require (
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.2.0 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
- golang.org/x/tools v0.0.0-20200918232735-d647fc253266 // indirect
+ golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d // indirect
google.golang.org/grpc v1.29.0 // indirect
+ google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 // indirect
gopkg.in/yaml.v2 v2.2.8 // indirect
gotest.tools v2.2.0+incompatible // indirect
)
diff --git a/go.sum b/go.sum
index 150d9b5b7..e713d2eaa 100644
--- a/go.sum
+++ b/go.sum
@@ -48,8 +48,8 @@ github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMX
github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=
github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
-github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=
-github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe/go.mod h1:cECdGN1O8G9bgKTlLhuPJimka6Xb/Gg7vYzCTNVxhvo=
+github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=
+github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY=
github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI=
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0=
@@ -82,12 +82,11 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c=
github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y=
-github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
-github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
@@ -116,8 +115,8 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
-github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
-github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/golang/protobuf v1.4.1 h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0=
+github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
@@ -125,6 +124,7 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=
github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=
@@ -147,7 +147,6 @@ github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uP
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
-github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@@ -158,6 +157,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=
+github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=
@@ -165,11 +166,7 @@ github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYN
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
-github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
-github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
-github.com/opencontainers/go-digest v1.0.0-rc1/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=
@@ -184,7 +181,6 @@ github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNia
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
-github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -195,10 +191,11 @@ github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7z
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
-github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
+github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
+github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -222,7 +219,6 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=
go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
-golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@@ -281,7 +277,6 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2By
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -327,8 +322,8 @@ golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.0.0-20200918232735-d647fc253266 h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ=
-golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
+golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=
+golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -354,8 +349,9 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
-google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=
google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=
+google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
@@ -363,6 +359,7 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac
google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=
google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
@@ -370,16 +367,13 @@ google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
-google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
-google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
-gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
+google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=
+google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
-gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
-gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
-gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
-gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
diff --git a/images/defs.bzl b/images/defs.bzl
index 61d7bbf73..c1f96e312 100644
--- a/images/defs.bzl
+++ b/images/defs.bzl
@@ -2,30 +2,33 @@
def _docker_image_impl(ctx):
importer = ctx.actions.declare_file(ctx.label.name)
+
importer_content = [
"#!/bin/bash",
"set -euo pipefail",
+ "source_file='%s'" % ctx.file.data.path,
+ "if [[ ! -f \"$source_file\" ]]; then",
+ " source_file='%s'" % ctx.file.data.short_path,
+ "fi",
"exec docker import " + " ".join([
"-c '%s'" % attr
for attr in ctx.attr.statements
- ]) + " " + " ".join([
- "'%s'" % f.path
- for f in ctx.files.data
- ]) + " $1",
+ ]) + " \"$source_file\" $1",
"",
]
+
ctx.actions.write(importer, "\n".join(importer_content), is_executable = True)
return [DefaultInfo(
- runfiles = ctx.runfiles(ctx.files.data),
+ runfiles = ctx.runfiles([ctx.file.data]),
executable = importer,
)]
docker_image = rule(
implementation = _docker_image_impl,
- doc = "Tool to load a Docker image; takes a single parameter (image name).",
+ doc = "Tool to import a Docker image; takes a single parameter (image name).",
attrs = {
"statements": attr.string_list(doc = "Extra Dockerfile directives."),
- "data": attr.label_list(doc = "All image data."),
+ "data": attr.label(doc = "Image filesystem tarball", allow_single_file = [".tgz", ".tar.gz"]),
},
executable = True,
)
diff --git a/nogo.yaml b/nogo.yaml
new file mode 100644
index 000000000..07fd9aa4d
--- /dev/null
+++ b/nogo.yaml
@@ -0,0 +1,495 @@
+groups:
+ # We define three basic groups: generated (all generated files),
+ # exteranl (all files outside the repository), and internal (all
+ # files within the local repository). We can't enforce many style
+ # checks on generated and external code, so enable those cases
+ # selectively for analyzers below.
+ - name: generated
+ regex: "^(bazel-genfiles|bazel-out|bazel-bin)/"
+ default: true
+ - name: external
+ regex: "^external/"
+ default: false
+ - name: internal
+ regex: ".*"
+ default: true
+global:
+ generated:
+ suppress:
+ # Suppress the basic style checks for
+ # generated code, but keep the analysis
+ # that are required for quality & security.
+ - "should not use ALL_CAPS in Go names"
+ - "should not use underscores"
+ - "comment on exported"
+ - "methods on the same type should have the same receiver name"
+ - "at least one file in a package"
+ - "package comment should be of the form"
+ # Generated code may have dead code paths.
+ - "identical build constraints"
+ - "no value of type"
+ - "is never used"
+ # go_embed_data rules generate unicode literals.
+ - "string literal contains the Unicode format character"
+ - "string literal contains the Unicode control character"
+ # Some external code will generate protov1
+ # implementations. These should be ignored.
+ - "proto.* is deprecated"
+ - "xxx_messageInfo_.*"
+ - "receiver name should be a reflection of its identity"
+ # Generated gRPC code is not compliant either.
+ - "error strings should not be capitalized"
+ - "grpc.Errorf is deprecated"
+ internal:
+ suppress:
+ # We use ALL_CAPS for system definitions,
+ # which are common enough in the code base
+ # that we shouldn't annotate exceptions.
+ #
+ # Same story for underscores.
+ - "should not use ALL_CAPS in Go names"
+ - "should not use underscores in Go names"
+ exclude:
+ # A variety of staticcheck and stylecheck
+ # rules apply here. These should be fixed
+ # and removed from here, and the global
+ # rules should be used sparingly.
+ - pkg/abi/linux/fuse.go:22
+ - pkg/abi/linux/fuse.go:25
+ - pkg/abi/linux/socket.go:113
+ - pkg/abi/linux/tty.go:73
+ - pkg/bpf/decoder.go:112
+ - pkg/cpuid/cpuid_x86.go:675
+ - pkg/eventchannel/event.go:193
+ - pkg/eventchannel/event.go:27
+ - pkg/eventchannel/event_test.go:22
+ - pkg/eventchannel/rate.go:19
+ - pkg/gohacks/gohacks_unsafe.go:33
+ - pkg/log/json.go:30
+ - pkg/log/log.go:359
+ - pkg/merkletree/merkletree.go:230
+ - pkg/merkletree/merkletree.go:243
+ - pkg/merkletree/merkletree.go:249
+ - pkg/merkletree/merkletree.go:266
+ - pkg/merkletree/merkletree.go:355
+ - pkg/merkletree/merkletree.go:369
+ - pkg/metric/metric_test.go:20
+ - pkg/p9/p9test/client_test.go:687
+ - pkg/p9/transport_test.go:196
+ - pkg/pool/pool.go:15
+ - pkg/refs/refcounter.go:510
+ - pkg/refs/refcounter_test.go:169
+ - pkg/refs_vfs2/refs.go:16
+ - pkg/safemem/block_unsafe.go:89
+ - pkg/seccomp/seccomp.go:82
+ - pkg/segment/test/set_functions.go:15
+ - pkg/sentry/arch/signal.go:166
+ - pkg/sentry/arch/signal.go:171
+ - pkg/sentry/control/pprof.go:196
+ - pkg/sentry/devices/memdev/full.go:58
+ - pkg/sentry/devices/memdev/null.go:59
+ - pkg/sentry/devices/memdev/random.go:68
+ - pkg/sentry/devices/memdev/zero.go:86
+ - pkg/sentry/fdimport/fdimport.go:15
+ - pkg/sentry/fs/attr.go:257
+ - pkg/sentry/fsbridge/fs.go:116
+ - pkg/sentry/fsbridge/vfs.go:124
+ - pkg/sentry/fsbridge/vfs.go:70
+ - pkg/sentry/fs/copy_up.go:365
+ - pkg/sentry/fs/copy_up_test.go:65
+ - pkg/sentry/fs/dev/net_tun.go:161
+ - pkg/sentry/fs/dev/net_tun.go:63
+ - pkg/sentry/fs/dev/null.go:97
+ - pkg/sentry/fs/dirent_cache.go:64
+ - pkg/sentry/fs/file_overlay.go:327
+ - pkg/sentry/fs/file_overlay.go:524
+ - pkg/sentry/fs/filetest/filetest.go:55
+ - pkg/sentry/fs/filetest/filetest.go:60
+ - pkg/sentry/fs/fs.go:77
+ - pkg/sentry/fs/fsutil/file.go:290
+ - pkg/sentry/fs/fsutil/file.go:346
+ - pkg/sentry/fs/fsutil/host_file_mapper.go:105
+ - pkg/sentry/fs/fsutil/inode_cached.go:676
+ - pkg/sentry/fs/fsutil/inode_cached.go:772
+ - pkg/sentry/fs/gofer/attr.go:120
+ - pkg/sentry/fs/gofer/fifo.go:33
+ - pkg/sentry/fs/gofer/inode.go:410
+ - pkg/sentry/fsimpl/devpts/devpts.go:110
+ - pkg/sentry/fsimpl/devpts/devpts.go:246
+ - pkg/sentry/fsimpl/devpts/devpts.go:50
+ - pkg/sentry/fsimpl/devpts/master.go:110
+ - pkg/sentry/fsimpl/devpts/master.go:55
+ - pkg/sentry/fsimpl/devpts/replica.go:113
+ - pkg/sentry/fsimpl/devpts/replica.go:57
+ - pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92
+ - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44
+ - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91
+ - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93
+ - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66
+ - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53
+ - pkg/sentry/fsimpl/eventfd/eventfd.go:268
+ - pkg/sentry/fsimpl/ext/directory.go:163
+ - pkg/sentry/fsimpl/ext/directory.go:164
+ - pkg/sentry/fsimpl/ext/extent_file.go:142
+ - pkg/sentry/fsimpl/ext/extent_file.go:143
+ - pkg/sentry/fsimpl/ext/ext.go:105
+ - pkg/sentry/fsimpl/ext/filesystem.go:287
+ - pkg/sentry/fsimpl/ext/regular_file.go:153
+ - pkg/sentry/fsimpl/ext/symlink.go:113
+ - pkg/sentry/fsimpl/fuse/connection_control.go:194
+ - pkg/sentry/fsimpl/fuse/dev.go:387
+ - pkg/sentry/fsimpl/fuse/dev_test.go:318
+ - pkg/sentry/fsimpl/fuse/fusefs.go:102
+ - pkg/sentry/fsimpl/fuse/read_write.go:129
+ - pkg/sentry/fsimpl/fuse/request_response.go:71
+ - pkg/sentry/fsimpl/gofer/directory.go:135
+ - pkg/sentry/fsimpl/gofer/filesystem.go:679
+ - pkg/sentry/fsimpl/gofer/gofer.go:1694
+ - pkg/sentry/fsimpl/gofer/gofer.go:276
+ - pkg/sentry/fsimpl/gofer/regular_file.go:81
+ - pkg/sentry/fsimpl/gofer/special_file.go:141
+ - pkg/sentry/fsimpl/host/host.go:184
+ - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50
+ - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90
+ - pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273
+ - pkg/sentry/fsimpl/kernfs/filesystem.go:247
+ - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320
+ - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497
+ - pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52
+ - pkg/sentry/fsimpl/overlay/directory.go:119
+ - pkg/sentry/fsimpl/overlay/filesystem.go:527
+ - pkg/sentry/fsimpl/overlay/non_directory.go:152
+ - pkg/sentry/fsimpl/overlay/overlay.go:115
+ - pkg/sentry/fsimpl/overlay/overlay.go:719
+ - pkg/sentry/fsimpl/pipefs/pipefs.go:74
+ - pkg/sentry/fsimpl/proc/filesystem.go:52
+ - pkg/sentry/fsimpl/proc/filesystem.go:81
+ - pkg/sentry/fsimpl/proc/subtasks.go:126
+ - pkg/sentry/fsimpl/proc/subtasks.go:189
+ - pkg/sentry/fsimpl/proc/task_fds.go:168
+ - pkg/sentry/fsimpl/proc/task_fds.go:228
+ - pkg/sentry/fsimpl/proc/task_fds.go:301
+ - pkg/sentry/fsimpl/proc/task_fds.go:318
+ - pkg/sentry/fsimpl/proc/task_fds.go:67
+ - pkg/sentry/fsimpl/proc/task_files.go:112
+ - pkg/sentry/fsimpl/proc/task_files.go:158
+ - pkg/sentry/fsimpl/proc/task_files.go:259
+ - pkg/sentry/fsimpl/proc/task_files.go:285
+ - pkg/sentry/fsimpl/proc/task_files.go:305
+ - pkg/sentry/fsimpl/proc/task_files.go:384
+ - pkg/sentry/fsimpl/proc/task_files.go:403
+ - pkg/sentry/fsimpl/proc/task_files.go:428
+ - pkg/sentry/fsimpl/proc/task_files.go:691
+ - pkg/sentry/fsimpl/proc/task_files.go:770
+ - pkg/sentry/fsimpl/proc/task_files.go:797
+ - pkg/sentry/fsimpl/proc/task_files.go:828
+ - pkg/sentry/fsimpl/proc/task_files.go:879
+ - pkg/sentry/fsimpl/proc/task_files.go:910
+ - pkg/sentry/fsimpl/proc/task_files.go:961
+ - pkg/sentry/fsimpl/proc/task.go:127
+ - pkg/sentry/fsimpl/proc/task.go:193
+ - pkg/sentry/fsimpl/proc/task_net.go:134
+ - pkg/sentry/fsimpl/proc/task_net.go:475
+ - pkg/sentry/fsimpl/proc/task_net.go:491
+ - pkg/sentry/fsimpl/proc/task_net.go:508
+ - pkg/sentry/fsimpl/proc/task_net.go:665
+ - pkg/sentry/fsimpl/proc/task_net.go:715
+ - pkg/sentry/fsimpl/proc/task_net.go:779
+ - pkg/sentry/fsimpl/proc/tasks_files.go:113
+ - pkg/sentry/fsimpl/proc/tasks_files.go:388
+ - pkg/sentry/fsimpl/proc/tasks.go:232
+ - pkg/sentry/fsimpl/proc/tasks_sys.go:145
+ - pkg/sentry/fsimpl/proc/tasks_sys.go:181
+ - pkg/sentry/fsimpl/proc/tasks_sys.go:239
+ - pkg/sentry/fsimpl/proc/tasks_sys.go:291
+ - pkg/sentry/fsimpl/proc/tasks_sys.go:375
+ - pkg/sentry/fsimpl/signalfd/signalfd.go:124
+ - pkg/sentry/fsimpl/signalfd/signalfd.go:15
+ - pkg/sentry/fsimpl/signalfd/signalfd.go:126
+ - pkg/sentry/fsimpl/sockfs/sockfs.go:36
+ - pkg/sentry/fsimpl/sockfs/sockfs.go:79
+ - pkg/sentry/fsimpl/sys/kcov.go:49
+ - pkg/sentry/fsimpl/sys/kcov.go:99
+ - pkg/sentry/fsimpl/sys/sys.go:118
+ - pkg/sentry/fsimpl/sys/sys.go:56
+ - pkg/sentry/fsimpl/testutil/testutil.go:257
+ - pkg/sentry/fsimpl/testutil/testutil.go:260
+ - pkg/sentry/fsimpl/timerfd/timerfd.go:87
+ - pkg/sentry/fsimpl/tmpfs/directory.go:112
+ - pkg/sentry/fsimpl/tmpfs/filesystem.go:195
+ - pkg/sentry/fsimpl/tmpfs/regular_file.go:226
+ - pkg/sentry/fsimpl/tmpfs/regular_file.go:346
+ - pkg/sentry/fsimpl/tmpfs/tmpfs.go:103
+ - pkg/sentry/fsimpl/tmpfs/tmpfs.go:733
+ - pkg/sentry/fsimpl/verity/filesystem.go:490
+ - pkg/sentry/fsimpl/verity/verity.go:156
+ - pkg/sentry/fsimpl/verity/verity.go:629
+ - pkg/sentry/fsimpl/verity/verity.go:672
+ - pkg/sentry/fs/mount.go:162
+ - pkg/sentry/fs/mount.go:256
+ - pkg/sentry/fs/mount_overlay.go:144
+ - pkg/sentry/fs/mounts.go:432
+ - pkg/sentry/fs/proc/exec_args.go:104
+ - pkg/sentry/fs/proc/exec_args.go:73
+ - pkg/sentry/fs/proc/fds.go:269
+ - pkg/sentry/fs/proc/loadavg.go:33
+ - pkg/sentry/fs/proc/meminfo.go:39
+ - pkg/sentry/fs/proc/mounts.go:193
+ - pkg/sentry/fs/proc/mounts.go:84
+ - pkg/sentry/fs/proc/net.go:125
+ - pkg/sentry/fs/proc/proc.go:146
+ - pkg/sentry/fs/proc/proc.go:204
+ - pkg/sentry/fs/proc/seqfile/seqfile.go:210
+ - pkg/sentry/fs/proc/sys.go:146
+ - pkg/sentry/fs/proc/sys.go:43
+ - pkg/sentry/fs/proc/sys_net.go:113
+ - pkg/sentry/fs/proc/sys_net.go:205
+ - pkg/sentry/fs/proc/sys_net.go:233
+ - pkg/sentry/fs/proc/sys_net.go:307
+ - pkg/sentry/fs/proc/sys_net.go:335
+ - pkg/sentry/fs/proc/sys_net.go:446
+ - pkg/sentry/fs/proc/sys_net.go:456
+ - pkg/sentry/fs/proc/sys_net.go:89
+ - pkg/sentry/fs/proc/task.go:170
+ - pkg/sentry/fs/proc/task.go:322
+ - pkg/sentry/fs/proc/task.go:427
+ - pkg/sentry/fs/proc/task.go:467
+ - pkg/sentry/fs/proc/task.go:500
+ - pkg/sentry/fs/proc/task.go:784
+ - pkg/sentry/fs/proc/task.go:839
+ - pkg/sentry/fs/proc/task.go:920
+ - pkg/sentry/fs/proc/uid_gid_map.go:108
+ - pkg/sentry/fs/proc/uid_gid_map.go:79
+ - pkg/sentry/fs/proc/uptime.go:75
+ - pkg/sentry/fs/ramfs/dir.go:447
+ - pkg/sentry/fs/tmpfs/inode_file.go:436
+ - pkg/sentry/fs/tmpfs/inode_file.go:537
+ - pkg/sentry/fs/tty/dir.go:313
+ - pkg/sentry/fs/tty/master.go:131
+ - pkg/sentry/fs/tty/master.go:91
+ - pkg/sentry/fs/tty/replica.go:116
+ - pkg/sentry/fs/tty/replica.go:88
+ - pkg/sentry/kernel/auth/id_map.go:269
+ - pkg/sentry/kernel/fasync/fasync.go:67
+ - pkg/sentry/kernel/kcov.go:209
+ - pkg/sentry/kernel/kcov.go:223
+ - pkg/sentry/kernel/kernel.go:343
+ - pkg/sentry/kernel/kernel.go:368
+ - pkg/sentry/kernel/pipe/node_test.go:112
+ - pkg/sentry/kernel/pipe/node_test.go:119
+ - pkg/sentry/kernel/pipe/node_test.go:130
+ - pkg/sentry/kernel/pipe/node_test.go:137
+ - pkg/sentry/kernel/pipe/node_test.go:149
+ - pkg/sentry/kernel/pipe/node_test.go:150
+ - pkg/sentry/kernel/pipe/node_test.go:158
+ - pkg/sentry/kernel/pipe/node_test.go:174
+ - pkg/sentry/kernel/pipe/node_test.go:180
+ - pkg/sentry/kernel/pipe/node_test.go:193
+ - pkg/sentry/kernel/pipe/node_test.go:202
+ - pkg/sentry/kernel/pipe/node_test.go:205
+ - pkg/sentry/kernel/pipe/node_test.go:216
+ - pkg/sentry/kernel/pipe/node_test.go:219
+ - pkg/sentry/kernel/pipe/node_test.go:271
+ - pkg/sentry/kernel/pipe/node_test.go:290
+ - pkg/sentry/kernel/pipe/pipe_test.go:93
+ - pkg/sentry/kernel/pipe/reader_writer.go:65
+ - pkg/sentry/kernel/posixtimer.go:157
+ - pkg/sentry/kernel/ptrace.go:218
+ - pkg/sentry/kernel/semaphore/semaphore.go:323
+ - pkg/sentry/kernel/sessions.go:123
+ - pkg/sentry/kernel/sessions.go:508
+ - pkg/sentry/kernel/signal_handlers.go:57
+ - pkg/sentry/kernel/task_context.go:72
+ - pkg/sentry/kernel/task_exit.go:67
+ - pkg/sentry/kernel/task_sched.go:255
+ - pkg/sentry/kernel/task_sched.go:280
+ - pkg/sentry/kernel/task_sched.go:323
+ - pkg/sentry/kernel/task_stop.go:192
+ - pkg/sentry/kernel/thread_group.go:530
+ - pkg/sentry/kernel/timekeeper.go:316
+ - pkg/sentry/kernel/vdso.go:106
+ - pkg/sentry/kernel/vdso.go:118
+ - pkg/sentry/memmap/memmap.go:103
+ - pkg/sentry/memmap/memmap.go:163
+ - pkg/sentry/mm/address_space.go:42
+ - pkg/sentry/mm/address_space.go:42
+ - pkg/sentry/mm/aio_context.go:208
+ - pkg/sentry/mm/aio_context.go:288
+ - pkg/sentry/mm/pma.go:683
+ - pkg/sentry/mm/special_mappable.go:80
+ - pkg/sentry/syscalls/linux/sys_sem.go:62
+ - pkg/sentry/syscalls/linux/sys_time.go:189
+ - pkg/sentry/usage/cpu.go:42
+ - pkg/sentry/vfs/anonfs.go:302
+ - pkg/sentry/vfs/anonfs.go:99
+ - pkg/sentry/vfs/dentry.go:214
+ - pkg/sentry/vfs/epoll.go:168
+ - pkg/sentry/vfs/epoll.go:314
+ - pkg/sentry/vfs/file_description.go:549
+ - pkg/sentry/vfs/file_description_impl_util.go:304
+ - pkg/sentry/vfs/file_description_impl_util.go:412
+ - pkg/sentry/vfs/filesystem.go:76
+ - pkg/sentry/vfs/lock.go:15
+ - pkg/sentry/vfs/lock.go:47
+ - pkg/sentry/vfs/memxattr/xattr.go:37
+ - pkg/sentry/vfs/mount.go:510
+ - pkg/sentry/vfs/mount.go:667
+ - pkg/sentry/vfs/mount_test.go:106
+ - pkg/sentry/vfs/mount_test.go:160
+ - pkg/sentry/vfs/mount_test.go:215
+ - pkg/sentry/vfs/mount_unsafe.go:153
+ - pkg/sentry/vfs/resolving_path.go:228
+ - pkg/sentry/vfs/vfs.go:897
+ - pkg/shim/runsc/runsc.go:16
+ - pkg/shim/runsc/utils.go:16
+ - pkg/shim/v1/proc/deleted_state.go:16
+ - pkg/shim/v1/proc/exec.go:16
+ - pkg/shim/v1/proc/exec_state.go:16
+ - pkg/shim/v1/proc/init.go:16
+ - pkg/shim/v1/proc/init_state.go:16
+ - pkg/shim/v1/proc/io.go:16
+ - pkg/shim/v1/proc/process.go:16
+ - pkg/shim/v1/proc/types.go:16
+ - pkg/shim/v1/proc/utils.go:16
+ - pkg/shim/v1/shim/api.go:16
+ - pkg/shim/v1/shim/platform.go:16
+ - pkg/shim/v1/shim/service.go:16
+ - pkg/shim/v1/utils/annotations.go:15
+ - pkg/shim/v1/utils/utils.go:15
+ - pkg/shim/v1/utils/volumes.go:15
+ - pkg/shim/v2/api.go:16
+ - pkg/shim/v2/epoll.go:18
+ - pkg/shim/v2/options/options.go:15
+ - pkg/shim/v2/options/options.go:24
+ - pkg/shim/v2/options/options.go:26
+ - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all.
+ - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22
+ - pkg/shim/v2/service.go:15
+ - pkg/shim/v2/service_linux.go:18
+ - pkg/state/tests/integer_test.go:23
+ - pkg/state/tests/integer_test.go:28
+ - pkg/sync/rwmutex_test.go:105
+ - pkg/syserr/host_linux.go:35
+ - pkg/unet/unet_test.go:634
+ - pkg/unet/unet_test.go:662
+ - pkg/unet/unet_test.go:703
+ - pkg/unet/unet_test.go:98
+ - pkg/usermem/addr.go:34
+ - pkg/usermem/usermem.go:171
+ - pkg/usermem/usermem.go:170
+ - runsc/boot/compat.go:22
+ - runsc/boot/compat.go:56
+ - runsc/boot/loader.go:1115
+ - runsc/boot/loader.go:1120
+ - runsc/cmd/checkpoint.go:151
+ - runsc/config/flags.go:32
+ - runsc/container/container.go:641
+ - runsc/container/container.go:988
+ - runsc/specutils/specutils.go:172
+ - runsc/specutils/specutils.go:428
+ - runsc/specutils/specutils.go:436
+ - runsc/specutils/specutils.go:442
+ - runsc/specutils/specutils.go:447
+ - runsc/specutils/specutils.go:454
+ - test/cmd/test_app/fds.go:171
+ - test/iptables/filter_output.go:251
+ - test/packetimpact/testbench/connections.go:77
+ - tools/bigquery/bigquery.go:106
+ - tools/checkescape/test1/test1.go:108
+ - tools/checkescape/test1/test1.go:122
+ - tools/checkescape/test1/test1.go:137
+ - tools/checkescape/test1/test1.go:151
+ - tools/checkescape/test1/test1.go:170
+ - tools/checkescape/test1/test1.go:39
+ - tools/checkescape/test1/test1.go:45
+ - tools/checkescape/test1/test1.go:50
+ - tools/checkescape/test1/test1.go:64
+ - tools/checkescape/test1/test1.go:80
+ - tools/checkescape/test1/test1.go:94
+ - tools/go_generics/imports.go:51
+ - tools/go_generics/imports.go:75
+ - tools/go_marshal/gomarshal/generator.go:177
+ - tools/go_marshal/gomarshal/generator.go:81
+ - tools/go_marshal/gomarshal/generator.go:85
+ - tools/go_marshal/test/escape/escape.go:15
+ - tools/go_marshal/test/test.go:164
+analyzers:
+ asmdecl:
+ external: # Enabled.
+ assign:
+ external:
+ exclude:
+ - gazelle/walk/walk.go
+ atomic:
+ external: # Enabled.
+ bools:
+ external: # Enabled.
+ buildtag:
+ external: # Enabled.
+ cgocall:
+ external: # Enabled.
+ shadow: # Disable for now.
+ generated:
+ exclude: [".*"]
+ internal:
+ exclude: [".*"]
+ composites: # Disable for now.
+ generated:
+ exclude: [".*"]
+ internal:
+ exclude: [".*"]
+ errorsas:
+ external: # Enabled.
+ httpresponse:
+ external: # Enabled.
+ loopclosure:
+ external: # Enabled.
+ nilfunc:
+ external: # Enabled.
+ nilness:
+ internal:
+ exclude:
+ - pkg/sentry/platform/kvm/kvm_test.go # Intentional.
+ - tools/bigquery/bigquery.go # False positive.
+ printf:
+ external: # Enabled.
+ shift:
+ external: # Enabled.
+ stringintconv:
+ external:
+ exclude:
+ - ".*protobuf/.*.go" # Bad conversions.
+ - ".*flate/huffman_bit_writer.go" # Bad conversion.
+ # Runtime internal violations.
+ - ".*reflect/value.go"
+ - ".*encoding/xml/xml.go"
+ - ".*runtime/pprof/internal/profile/proto.go"
+ - ".*fmt/scan.go"
+ - ".*go/types/conversions.go"
+ - ".*golang.org/x/net/dns/dnsmessage/message.go"
+ tests:
+ external: # Enabled.
+ unmarshal:
+ external: # Enabled.
+ unreachable:
+ external: # Enabled.
+ unsafeptr:
+ internal:
+ exclude:
+ - ".*_test.go" # Exclude tests.
+ - "pkg/flipcall/.*_unsafe.go" # Special case.
+ - pkg/gohacks/gohacks_unsafe.go # Special case.
+ - pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go # Special case.
+ - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case.
+ - pkg/sentry/platform/kvm/machine_unsafe.go # Special case.
+ - pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go # Special case.
+ - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case.
+ - pkg/sentry/vfs/mount_unsafe.go # Special case.
+ - pkg/state/decode_unsafe.go # Special case.
+ unusedresult:
+ external: # Enabled.
+ checkescape:
+ external: # Enabled.
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 4a26e28de..a0654df2f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -55,6 +55,8 @@ go_library(
"sched.go",
"seccomp.go",
"sem.go",
+ "sem_amd64.go",
+ "sem_arm64.go",
"shm.go",
"signal.go",
"signalfd.go",
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 7df02dd6d..006b5a525 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -121,6 +121,9 @@ const (
// Constants from uapi/linux/fsverity.h.
const (
+ FS_VERITY_HASH_ALG_SHA256 = 1
+ FS_VERITY_HASH_ALG_SHA512 = 2
+
FS_IOC_ENABLE_VERITY = 1082156677
FS_IOC_MEASURE_VERITY = 3221513862
)
diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go
index 487a626cc..1b2f76c0b 100644
--- a/pkg/abi/linux/sem.go
+++ b/pkg/abi/linux/sem.go
@@ -34,18 +34,6 @@ const (
const SEM_UNDO = 0x1000
-// SemidDS is equivalent to struct semid64_ds.
-//
-// +marshal
-type SemidDS struct {
- SemPerm IPCPerm
- SemOTime TimeT
- SemCTime TimeT
- SemNSems uint64
- unused3 uint64
- unused4 uint64
-}
-
// Sembuf is equivalent to struct sembuf.
//
// +marshal slice:SembufSlice
diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go
new file mode 100644
index 000000000..ab980cb4f
--- /dev/null
+++ b/pkg/abi/linux/sem_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: arch/x86/include/uapi/asm/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ unused1 uint64
+ SemCTime TimeT
+ unused2 uint64
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go
new file mode 100644
index 000000000..521468fb1
--- /dev/null
+++ b/pkg/abi/linux/sem_arm64.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// SemidDS is equivalent to struct semid64_ds.
+//
+// Source: include/uapi/asm-generic/sembuf.h
+//
+// +marshal
+type SemidDS struct {
+ SemPerm IPCPerm
+ SemOTime TimeT
+ SemCTime TimeT
+ SemNSems uint64
+ unused3 uint64
+ unused4 uint64
+}
diff --git a/pkg/context/context.go b/pkg/context/context.go
index 2613bc752..f3031fc60 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()}
func Background() Context {
return bgContext
}
+
+// WithValue returns a copy of parent in which the value associated with key is
+// val.
+func WithValue(parent Context, key, val interface{}) Context {
+ return &withValue{
+ Context: parent,
+ key: key,
+ val: val,
+ }
+}
+
+type withValue struct {
+ Context
+ key interface{}
+ val interface{}
+}
+
+// Value implements Context.Value.
+func (ctx *withValue) Value(key interface{}) interface{} {
+ if key == ctx.key {
+ return ctx.val
+ }
+ return ctx.Context.Value(key)
+}
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index bee28b68d..a493e3407 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "eventchannel",
srcs = [
"event.go",
+ "event_any.go",
"rate.go",
],
visibility = ["//:sandbox"],
@@ -14,8 +15,9 @@ go_library(
"//pkg/log",
"//pkg/sync",
"//pkg/unet",
- "@com_github_golang_protobuf//proto:go_default_library",
- "@com_github_golang_protobuf//ptypes:go_default_library_gen",
+ "@org_golang_google_protobuf//encoding/prototext:go_default_library",
+ "@org_golang_google_protobuf//proto:go_default_library",
+ "@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_x_time//rate:go_default_library",
],
)
@@ -32,6 +34,6 @@ go_test(
library = ":eventchannel",
deps = [
"//pkg/sync",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "@org_golang_google_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index 9a29c58bd..7172ce75d 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -24,8 +24,8 @@ import (
"fmt"
"syscall"
- "github.com/golang/protobuf/proto"
- "github.com/golang/protobuf/ptypes"
+ "google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/proto"
pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
@@ -118,22 +118,6 @@ func (me *multiEmitter) Close() error {
return err
}
-func marshal(msg proto.Message) ([]byte, error) {
- anypb, err := ptypes.MarshalAny(msg)
- if err != nil {
- return nil, err
- }
-
- // Wire format is uvarint message length followed by binary proto.
- bufMsg, err := proto.Marshal(anypb)
- if err != nil {
- return nil, err
- }
- p := make([]byte, binary.MaxVarintLen64)
- n := binary.PutUvarint(p, uint64(len(bufMsg)))
- return append(p[:n], bufMsg...), nil
-}
-
// socketEmitter emits proto messages on a socket.
type socketEmitter struct {
socket *unet.Socket
@@ -155,10 +139,19 @@ func SocketEmitter(fd int) (Emitter, error) {
// Emit implements Emitter.Emit.
func (s *socketEmitter) Emit(msg proto.Message) (bool, error) {
- p, err := marshal(msg)
+ any, err := newAny(msg)
if err != nil {
return false, err
}
+ bufMsg, err := proto.Marshal(any)
+ if err != nil {
+ return false, err
+ }
+
+ // Wire format is uvarint message length followed by binary proto.
+ p := make([]byte, binary.MaxVarintLen64)
+ n := binary.PutUvarint(p, uint64(len(bufMsg)))
+ p = append(p[:n], bufMsg...)
for done := 0; done < len(p); {
n, err := s.socket.Write(p[done:])
if err != nil {
@@ -166,6 +159,7 @@ func (s *socketEmitter) Emit(msg proto.Message) (bool, error) {
}
done += n
}
+
return false, nil
}
@@ -189,9 +183,13 @@ func DebugEmitterFrom(inner Emitter) Emitter {
}
func (d *debugEmitter) Emit(msg proto.Message) (bool, error) {
+ text, err := prototext.Marshal(msg)
+ if err != nil {
+ return false, err
+ }
ev := &pb.DebugEvent{
- Name: proto.MessageName(msg),
- Text: proto.MarshalTextString(msg),
+ Name: string(msg.ProtoReflect().Descriptor().FullName()),
+ Text: string(text),
}
return d.inner.Emit(ev)
}
diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto
index 34468f072..4b24ac47c 100644
--- a/pkg/eventchannel/event.proto
+++ b/pkg/eventchannel/event.proto
@@ -16,7 +16,7 @@ syntax = "proto3";
package gvisor;
-// A debug event encapsulates any other event protobuf in text format. This is
+// DebugEvent encapsulates any other event protobuf in text format. This is
// useful because clients reading events emitted this way do not need to link
// the event protobufs to display them in a human-readable format.
message DebugEvent {
diff --git a/pkg/eventchannel/event_any.go b/pkg/eventchannel/event_any.go
new file mode 100644
index 000000000..a5549f6cd
--- /dev/null
+++ b/pkg/eventchannel/event_any.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventchannel
+
+import (
+ "google.golang.org/protobuf/types/known/anypb"
+
+ "google.golang.org/protobuf/proto"
+)
+
+func newAny(m proto.Message) (*anypb.Any, error) {
+ return anypb.New(m)
+}
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 43750360b..0dd408f76 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -19,7 +19,7 @@ import (
"testing"
"time"
- "github.com/golang/protobuf/proto"
+ "google.golang.org/protobuf/proto"
"gvisor.dev/gvisor/pkg/sync"
)
diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go
index 179226c92..74960e16a 100644
--- a/pkg/eventchannel/rate.go
+++ b/pkg/eventchannel/rate.go
@@ -15,8 +15,8 @@
package eventchannel
import (
- "github.com/golang/protobuf/proto"
"golang.org/x/time/rate"
+ "google.golang.org/protobuf/proto"
)
// rateLimitedEmitter wraps an emitter and limits events to the given limits.
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
index a8fcb2e19..501a9ef21 100644
--- a/pkg/merkletree/BUILD
+++ b/pkg/merkletree/BUILD
@@ -6,12 +6,18 @@ go_library(
name = "merkletree",
srcs = ["merkletree.go"],
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
go_test(
name = "merkletree_test",
srcs = ["merkletree_test.go"],
library = ":merkletree",
- deps = ["//pkg/usermem"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/usermem",
+ ],
)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index d8227b8bd..18457d287 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -18,21 +18,32 @@ package merkletree
import (
"bytes"
"crypto/sha256"
+ "crypto/sha512"
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
const (
// sha256DigestSize specifies the digest size of a SHA256 hash.
sha256DigestSize = 32
+ // sha512DigestSize specifies the digest size of a SHA512 hash.
+ sha512DigestSize = 64
)
// DigestSize returns the size (in bytes) of a digest.
-// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
-func DigestSize() int {
- return sha256DigestSize
+// TODO(b/156980949): Allow config SHA384.
+func DigestSize(hashAlgorithm int) int {
+ switch hashAlgorithm {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ return sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ return sha512DigestSize
+ default:
+ return -1
+ }
}
// Layout defines the scale of a Merkle tree.
@@ -51,11 +62,19 @@ type Layout struct {
// InitLayout initializes and returns a new Layout object describing the structure
// of a tree. dataSize specifies the size of input data in bytes.
-func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
+func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) {
layout := Layout{
blockSize: usermem.PageSize,
- // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
- digestSize: sha256DigestSize,
+ }
+
+ // TODO(b/156980949): Allow config SHA384.
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ layout.digestSize = sha256DigestSize
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ layout.digestSize = sha512DigestSize
+ default:
+ return Layout{}, fmt.Errorf("unexpected hash algorithms")
}
// treeStart is the offset (in bytes) of the first level of the tree in
@@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout {
}
layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize)
- return layout
+ return layout, nil
}
// hashesPerBlock() returns the number of digests in each block. For example,
@@ -139,12 +158,33 @@ func (d *VerityDescriptor) String() string {
}
// verify generates a hash from d, and compares it with expected.
-func (d *VerityDescriptor) verify(expected []byte) error {
- h := sha256.Sum256([]byte(d.String()))
+func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error {
+ h, err := hashData([]byte(d.String()), hashAlgorithms)
+ if err != nil {
+ return err
+ }
if !bytes.Equal(h[:], expected) {
return fmt.Errorf("unexpected root hash")
}
return nil
+
+}
+
+// hashData hashes data and returns the result hash based on the hash
+// algorithms.
+func hashData(data []byte, hashAlgorithms int) ([]byte, error) {
+ var digest []byte
+ switch hashAlgorithms {
+ case linux.FS_VERITY_HASH_ALG_SHA256:
+ digestArray := sha256.Sum256(data)
+ digest = digestArray[:]
+ case linux.FS_VERITY_HASH_ALG_SHA512:
+ digestArray := sha512.Sum512(data)
+ digest = digestArray[:]
+ default:
+ return nil, fmt.Errorf("unexpected hash algorithms")
+ }
+ return digest, nil
}
// GenerateParams contains the parameters used to generate a Merkle tree.
@@ -161,6 +201,8 @@ type GenerateParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// TreeReader is a reader for the Merkle tree.
TreeReader io.ReaderAt
// TreeWriter is a writer for the Merkle tree.
@@ -176,7 +218,10 @@ type GenerateParams struct {
// Generate returns a hash of a VerityDescriptor, which contains the file
// metadata and the hash from file content.
func Generate(params *GenerateParams) ([]byte, error) {
- layout := InitLayout(params.Size, params.DataAndTreeInSameFile)
+ layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return nil, err
+ }
numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize
@@ -218,10 +263,13 @@ func Generate(params *GenerateParams) ([]byte, error) {
return nil, err
}
// Hash the bytes in buf.
- digest := sha256.Sum256(buf)
+ digest, err := hashData(buf, params.HashAlgorithms)
+ if err != nil {
+ return nil, err
+ }
if level == layout.rootLevel() {
- root = digest[:]
+ root = digest
}
// Write the generated hash to the end of the tree file.
@@ -246,8 +294,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
GID: params.GID,
RootHash: root,
}
- ret := sha256.Sum256([]byte(descriptor.String()))
- return ret[:], nil
+ return hashData([]byte(descriptor.String()), params.HashAlgorithms)
}
// VerifyParams contains the params used to verify a portion of a file against
@@ -269,6 +316,8 @@ type VerifyParams struct {
UID uint32
// GID is the group ID of the target file.
GID uint32
+ // HashAlgorithms is the algorithms used to hash data.
+ HashAlgorithms int
// ReadOffset is the offset of the data range to be verified.
ReadOffset int64
// ReadSize is the size of the data range to be verified.
@@ -298,7 +347,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error {
GID: params.GID,
RootHash: root,
}
- return descriptor.verify(params.Expected)
+ return descriptor.verify(params.Expected, params.HashAlgorithms)
}
// Verify verifies the content read from data with offset. The content is
@@ -313,7 +362,10 @@ func Verify(params *VerifyParams) (int64, error) {
if params.ReadSize < 0 {
return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize)
}
- layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile)
+ layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile)
+ if err != nil {
+ return 0, err
+ }
if params.ReadSize == 0 {
return 0, verifyMetadata(params, &layout)
}
@@ -354,7 +406,7 @@ func Verify(params *VerifyParams) (int64, error) {
UID: params.UID,
GID: params.GID,
}
- if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil {
+ if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil {
return 0, err
}
@@ -395,7 +447,7 @@ func Verify(params *VerifyParams) (int64, error) {
// fails if the calculated hash from block is different from any level of
// hashes stored in tree. And the final root hash is compared with
// expected.
-func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error {
+func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error {
if len(dataBlock) != int(layout.blockSize) {
return fmt.Errorf("incorrect block size")
}
@@ -406,8 +458,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
for level := 0; level < layout.numLevels(); level++ {
// Calculate hash.
if level == 0 {
- digestArray := sha256.Sum256(dataBlock)
- digest = digestArray[:]
+ h, err := hashData(dataBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
} else {
// Read a block in previous level that contains the
// hash we just generated, and generate a next level
@@ -415,8 +470,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil {
return err
}
- digestArray := sha256.Sum256(treeBlock)
- digest = digestArray[:]
+ h, err := hashData(treeBlock, hashAlgorithms)
+ if err != nil {
+ return err
+ }
+ digest = h
}
// Read the digest for the current block and store in
@@ -434,5 +492,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout,
// Verification for the tree succeeded. Now hash the descriptor with
// the root hash and compare it with expected.
descriptor.RootHash = digest
- return descriptor.verify(expected)
+ return descriptor.verify(expected, hashAlgorithms)
}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
index e1350ebda..0782ca3e7 100644
--- a/pkg/merkletree/merkletree_test.go
+++ b/pkg/merkletree/merkletree_test.go
@@ -22,54 +22,114 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
func TestLayout(t *testing.T) {
testCases := []struct {
dataSize int64
+ hashAlgorithms int
dataAndTreeInSameFile bool
+ expectedDigestSize int64
expectedLevelOffset []int64
}{
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0},
},
{
dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
+ expectedLevelOffset: []int64{usermem.PageSize},
+ },
+ {
+ dataSize: 100,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
expectedLevelOffset: []int64{usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize},
},
{
dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize},
+ },
+ {
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize},
},
{
+ dataSize: 1000000,
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize},
+ },
+ {
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: false,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize},
},
{
dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: false,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
dataAndTreeInSameFile: true,
+ expectedDigestSize: 32,
expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize},
},
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ dataAndTreeInSameFile: true,
+ expectedDigestSize: 64,
+ expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize},
+ },
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
- l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile)
+ l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile)
+ if err != nil {
+ t.Fatalf("Failed to InitLayout: %v", err)
+ }
if l.blockSize != int64(usermem.PageSize) {
t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize)
}
- if l.digestSize != sha256DigestSize {
+ if l.digestSize != tc.expectedDigestSize {
t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize)
}
if l.numLevels() != len(tc.expectedLevelOffset) {
@@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) {
// The input data has size dataSize. It starts with the data in startWith,
// and all other bytes are zeroes.
testCases := []struct {
- data []byte
- expectedHash []byte
+ data []byte
+ hashAlgorithms int
+ expectedHash []byte
}{
{
- data: bytes.Repeat([]byte{0}, usermem.PageSize),
- expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{14, 27, 126, 158, 9, 94, 163, 51, 243, 162, 82, 167, 183, 127, 93, 121, 221, 23, 184, 59, 104, 166, 111, 49, 161, 195, 229, 111, 121, 201, 233, 68, 10, 154, 78, 142, 154, 236, 170, 156, 110, 167, 15, 144, 155, 97, 241, 235, 202, 233, 246, 217, 138, 88, 152, 179, 238, 46, 247, 185, 125, 20, 101, 201},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{55, 204, 240, 1, 224, 252, 58, 131, 251, 174, 45, 140, 107, 57, 118, 11, 18, 236, 203, 204, 19, 59, 27, 196, 3, 78, 21, 7, 22, 98, 197, 128, 17, 128, 90, 122, 54, 83, 253, 108, 156, 67, 59, 229, 236, 241, 69, 88, 99, 44, 127, 109, 204, 183, 150, 232, 187, 57, 228, 137, 209, 235, 241, 172},
},
{
- data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
- expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
},
{
- data: []byte{'a'},
- expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11},
+ data: []byte{'a'},
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{207, 233, 114, 94, 113, 212, 243, 160, 59, 232, 226, 77, 28, 81, 176, 61, 211, 213, 222, 190, 148, 196, 90, 166, 237, 56, 113, 148, 230, 154, 23, 105, 14, 97, 144, 211, 12, 122, 226, 207, 167, 203, 136, 193, 38, 249, 227, 187, 92, 238, 101, 97, 170, 255, 246, 209, 246, 98, 241, 150, 175, 253, 173, 206},
},
{
- data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
- expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79},
+ },
+ {
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512,
+ expectedHash: []byte{110, 103, 29, 250, 27, 211, 235, 119, 112, 65, 49, 156, 6, 92, 66, 105, 133, 1, 187, 172, 169, 13, 186, 34, 105, 72, 252, 131, 12, 159, 91, 188, 79, 184, 240, 227, 40, 164, 72, 193, 65, 31, 227, 153, 191, 6, 117, 42, 82, 122, 33, 255, 92, 215, 215, 249, 2, 131, 170, 134, 39, 192, 222, 33},
},
}
@@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) {
Mode: defaultMode,
UID: defaultUID,
GID: defaultGID,
+ HashAlgorithms: tc.hashAlgorithms,
TreeReader: &tree,
TreeWriter: &tree,
DataAndTreeInSameFile: dataAndTreeInSameFile,
@@ -348,77 +434,81 @@ func TestVerify(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Flip a bit in data and checks Verify results.
- var buf bytes.Buffer
- data[tc.modifyByte] ^= 1
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: tc.dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: tc.verifyStart,
- ReadSize: tc.verifySize,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
- if tc.modifyName {
- verifyParams.Name = defaultName + "abc"
- }
- if tc.modifyMode {
- verifyParams.Mode = defaultMode + 1
- }
- if tc.modifyUID {
- verifyParams.UID = defaultUID + 1
- }
- if tc.modifyGID {
- verifyParams.GID = defaultGID + 1
- }
- if tc.shouldSucceed {
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed when expected to succeed: %v", err)
+ // Flip a bit in data and checks Verify results.
+ var buf bytes.Buffer
+ data[tc.modifyByte] ^= 1
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: tc.dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: tc.verifyStart,
+ ReadSize: tc.verifySize,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
}
- if n != tc.verifySize {
- t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ if tc.modifyName {
+ verifyParams.Name = defaultName + "abc"
}
- if int64(buf.Len()) != tc.verifySize {
- t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ if tc.modifyMode {
+ verifyParams.Mode = defaultMode + 1
}
- if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
+ if tc.modifyUID {
+ verifyParams.UID = defaultUID + 1
}
- } else {
- if _, err := Verify(&verifyParams); err == nil {
- t.Errorf("Verification succeeded when expected to fail")
+ if tc.modifyGID {
+ verifyParams.GID = defaultGID + 1
+ }
+ if tc.shouldSucceed {
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed when expected to succeed: %v", err)
+ }
+ if n != tc.verifySize {
+ t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize)
+ }
+ if int64(buf.Len()) != tc.verifySize {
+ t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize)
+ }
+ if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
+ } else {
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Errorf("Verification succeeded when expected to fail")
+ }
}
}
}
@@ -435,87 +525,91 @@ func TestVerifyRandom(t *testing.T) {
// Generate random bytes in data.
rand.Read(data)
- for _, dataAndTreeInSameFile := range []bool{false, true} {
- var tree bytesReadWriter
- genParams := GenerateParams{
- Size: int64(len(data)),
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- TreeReader: &tree,
- TreeWriter: &tree,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} {
+ for _, dataAndTreeInSameFile := range []bool{false, true} {
+ var tree bytesReadWriter
+ genParams := GenerateParams{
+ Size: int64(len(data)),
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ TreeReader: &tree,
+ TreeWriter: &tree,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- if dataAndTreeInSameFile {
- tree.Write(data)
- genParams.File = &tree
- } else {
- genParams.File = &bytesReadWriter{
- bytes: data,
+ if dataAndTreeInSameFile {
+ tree.Write(data)
+ genParams.File = &tree
+ } else {
+ genParams.File = &bytesReadWriter{
+ bytes: data,
+ }
+ }
+ hash, err := Generate(&genParams)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
}
- }
- hash, err := Generate(&genParams)
- if err != nil {
- t.Fatalf("Generate failed: %v", err)
- }
- // Pick a random portion of data.
- start := rand.Int63n(dataSize - 1)
- size := rand.Int63n(dataSize) + 1
+ // Pick a random portion of data.
+ start := rand.Int63n(dataSize - 1)
+ size := rand.Int63n(dataSize) + 1
- var buf bytes.Buffer
- verifyParams := VerifyParams{
- Out: &buf,
- File: bytes.NewReader(data),
- Tree: &tree,
- Size: dataSize,
- Name: defaultName,
- Mode: defaultMode,
- UID: defaultUID,
- GID: defaultGID,
- ReadOffset: start,
- ReadSize: size,
- Expected: hash,
- DataAndTreeInSameFile: dataAndTreeInSameFile,
- }
+ var buf bytes.Buffer
+ verifyParams := VerifyParams{
+ Out: &buf,
+ File: bytes.NewReader(data),
+ Tree: &tree,
+ Size: dataSize,
+ Name: defaultName,
+ Mode: defaultMode,
+ UID: defaultUID,
+ GID: defaultGID,
+ HashAlgorithms: hashAlgorithms,
+ ReadOffset: start,
+ ReadSize: size,
+ Expected: hash,
+ DataAndTreeInSameFile: dataAndTreeInSameFile,
+ }
- // Checks that the random portion of data from the original data is
- // verified successfully.
- n, err := Verify(&verifyParams)
- if err != nil && err != io.EOF {
- t.Errorf("Verification failed for correct data: %v", err)
- }
- if size > dataSize-start {
- size = dataSize - start
- }
- if n != size {
- t.Errorf("Got Verify output size %d, want %d", n, size)
- }
- if int64(buf.Len()) != size {
- t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
- }
- if !bytes.Equal(data[start:start+size], buf.Bytes()) {
- t.Errorf("Incorrect output buf from Verify")
- }
+ // Checks that the random portion of data from the original data is
+ // verified successfully.
+ n, err := Verify(&verifyParams)
+ if err != nil && err != io.EOF {
+ t.Errorf("Verification failed for correct data: %v", err)
+ }
+ if size > dataSize-start {
+ size = dataSize - start
+ }
+ if n != size {
+ t.Errorf("Got Verify output size %d, want %d", n, size)
+ }
+ if int64(buf.Len()) != size {
+ t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size)
+ }
+ if !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output buf from Verify")
+ }
- // Verify that modified metadata should fail verification.
- buf.Reset()
- verifyParams.Name = defaultName + "abc"
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verify succeeded for modified metadata, expect failure")
- }
+ // Verify that modified metadata should fail verification.
+ buf.Reset()
+ verifyParams.Name = defaultName + "abc"
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verify succeeded for modified metadata, expect failure")
+ }
- // Flip a random bit in randPortion, and check that verification fails.
- buf.Reset()
- randBytePos := rand.Int63n(size)
- data[start+randBytePos] ^= 1
- verifyParams.File = bytes.NewReader(data)
- verifyParams.Name = defaultName
+ // Flip a random bit in randPortion, and check that verification fails.
+ buf.Reset()
+ randBytePos := rand.Int63n(size)
+ data[start+randBytePos] ^= 1
+ verifyParams.File = bytes.NewReader(data)
+ verifyParams.Name = defaultName
- if _, err := Verify(&verifyParams); err == nil {
- t.Error("Verification succeeded for modified data, expect failure")
+ if _, err := Verify(&verifyParams); err == nil {
+ t.Error("Verification succeeded for modified data, expect failure")
+ }
}
}
}
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 58305009d..0a6a5d215 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -27,6 +27,6 @@ go_test(
deps = [
":metric_go_proto",
"//pkg/eventchannel",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "@org_golang_google_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index c425ea532..aefd0ea5c 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -17,7 +17,7 @@ package metric
import (
"testing"
- "github.com/golang/protobuf/proto"
+ "google.golang.org/protobuf/proto"
"gvisor.dev/gvisor/pkg/eventchannel"
pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
)
diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD
index 577b827a5..245e33d2d 100644
--- a/pkg/refs_vfs2/BUILD
+++ b/pkg/refsvfs2/BUILD
@@ -19,8 +19,16 @@ go_template(
)
go_library(
- name = "refs_vfs2",
- srcs = ["refs.go"],
- visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/context"],
+ name = "refsvfs2",
+ srcs = [
+ "refs.go",
+ "refs_map.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go
index 99a074e96..ef8beb659 100644
--- a/pkg/refs_vfs2/refs.go
+++ b/pkg/refsvfs2/refs.go
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package refs_vfs2 defines an interface for a reference-counted object.
-package refs_vfs2
+// Package refsvfs2 defines an interface for a reference-counted object.
+package refsvfs2
import (
"gvisor.dev/gvisor/pkg/context"
diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go
new file mode 100644
index 000000000..be75b0cc2
--- /dev/null
+++ b/pkg/refsvfs2/refs_map.go
@@ -0,0 +1,97 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package refsvfs2
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// TODO(gvisor.dev/issue/1193): re-enable once kernfs refs are fixed.
+var ignored []string = []string{"kernfs.", "proc.", "sys.", "devpts.", "fuse."}
+
+var (
+ // liveObjects is a global map of reference-counted objects. Objects are
+ // inserted when leak check is enabled, and they are removed when they are
+ // destroyed. It is protected by liveObjectsMu.
+ liveObjects map[CheckedObject]struct{}
+ liveObjectsMu sync.Mutex
+)
+
+// CheckedObject represents a reference-counted object with an informative
+// leak detection message.
+type CheckedObject interface {
+ // LeakMessage supplies a warning to be printed upon leak detection.
+ LeakMessage() string
+}
+
+func init() {
+ liveObjects = make(map[CheckedObject]struct{})
+}
+
+// LeakCheckEnabled returns whether leak checking is enabled. The following
+// functions should only be called if it returns true.
+func LeakCheckEnabled() bool {
+ return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking
+}
+
+// Register adds obj to the live object map.
+func Register(obj CheckedObject, typ string) {
+ for _, str := range ignored {
+ if strings.Contains(typ, str) {
+ return
+ }
+ }
+ liveObjectsMu.Lock()
+ if _, ok := liveObjects[obj]; ok {
+ panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj))
+ }
+ liveObjects[obj] = struct{}{}
+ liveObjectsMu.Unlock()
+}
+
+// Unregister removes obj from the live object map.
+func Unregister(obj CheckedObject, typ string) {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ if _, ok := liveObjects[obj]; !ok {
+ for _, str := range ignored {
+ if strings.Contains(typ, str) {
+ return
+ }
+ }
+ panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj))
+ }
+ delete(liveObjects, obj)
+}
+
+// DoLeakCheck iterates through the live object map and logs a message for each
+// object. It is called once no reference-counted objects should be reachable
+// anymore, at which point anything left in the map is considered a leak.
+func DoLeakCheck() {
+ liveObjectsMu.Lock()
+ defer liveObjectsMu.Unlock()
+ leaked := len(liveObjects)
+ if leaked > 0 {
+ log.Warningf("Leak checking detected %d leaked objects:", leaked)
+ for obj := range liveObjects {
+ log.Warningf(obj.LeakMessage())
+ }
+ }
+}
diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go
index d9b552896..ec295ef5b 100644
--- a/pkg/refs_vfs2/refs_template.go
+++ b/pkg/refsvfs2/refs_template.go
@@ -21,11 +21,9 @@ package refs_template
import (
"fmt"
- "runtime"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/log"
- refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
)
// T is the type of the reference counted object. It is only used to customize
@@ -42,11 +40,6 @@ var ownerType *T
// Note that the number of references is actually refCount + 1 so that a default
// zero-value Refs object contains one reference.
//
-// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in
-// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount.
-// This will allow us to add stack trace information to the leak messages
-// without growing the size of Refs.
-//
// +stateify savable
type Refs struct {
// refCount is composed of two fields:
@@ -59,24 +52,16 @@ type Refs struct {
refCount int64
}
-func (r *Refs) finalize() {
- var note string
- switch refs_vfs1.GetLeakMode() {
- case refs_vfs1.NoLeakChecking:
- return
- case refs_vfs1.UninitializedLeakChecking:
- note = "(Leak checker uninitialized): "
- }
- if n := r.ReadRefs(); n != 0 {
- log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n)
+// EnableLeakCheck enables reference leak checking on r.
+func (r *Refs) EnableLeakCheck() {
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(r, fmt.Sprintf("%T", ownerType))
}
}
-// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
-func (r *Refs) EnableLeakCheck() {
- if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
- runtime.SetFinalizer(r, (*Refs).finalize)
- }
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (r *Refs) LeakMessage() string {
+ return fmt.Sprintf("%T %p: reference count of %d instead of 0", ownerType, r, r.ReadRefs())
}
// ReadRefs returns the current number of references. The returned count is
@@ -91,7 +76,7 @@ func (r *Refs) ReadRefs() int64 {
//go:nosplit
func (r *Refs) IncRef() {
if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
- panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType))
+ panic(fmt.Sprintf("Incrementing non-positive count %p on %T", r, ownerType))
}
}
@@ -134,9 +119,18 @@ func (r *Refs) DecRef(destroy func()) {
panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType))
case v == -1:
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Unregister(r, fmt.Sprintf("%T", ownerType))
+ }
// Call the destructor.
if destroy != nil {
destroy()
}
}
}
+
+func (r *Refs) afterLoad() {
+ if refsvfs2.LeakCheckEnabled() && r.ReadRefs() > 0 {
+ r.EnableLeakCheck()
+ }
+}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index 41feeffe3..d800f2c85 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
s.Kernel.Kill(kernel.ExitStatus{})
},
}
- return saveOpts.Save(s.Kernel, s.Watchdog)
+ return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog)
}
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
index 655ea549b..ff5d49fbd 100644
--- a/pkg/sentry/devices/tundev/tundev.go
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -39,6 +39,8 @@ const (
)
// tunDevice implements vfs.Device for /dev/net/tun.
+//
+// +stateify savable
type tunDevice struct{}
// Open implements vfs.Device.Open.
@@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt
}
// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+//
+// +stateify savable
type tunFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 1390a9a7f..4468f5dd2 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() {
f.mappings = make(map[uint64]mapping)
}
+// IsInited returns true if f.Init() has been called. This is used when
+// restoring a checkpoint that contains a HostFileMapper that may or may not
+// have been initialized.
+func (f *HostFileMapper) IsInited() bool {
+ return f.refs != nil
+}
+
// NewHostFileMapper returns an initialized HostFileMapper allocated on the
// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 22d658acf..450044c9c 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo
"gid_map": newGIDMap(t, msrc),
"io": newIO(t, msrc, isThreadGroup),
"maps": newMaps(t, msrc),
+ "mem": newMem(t, msrc),
"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
"net": newNetDir(t, msrc),
@@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
+// memData implements fs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memData struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// memDataFile implements fs.FileOperations for /proc/[pid]/mem.
+//
+// +stateify savable
+type memDataFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ inode := &memData{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, inode, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (m *memData) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, m.t, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(m.t); err != nil {
+ return nil, err
+ }
+ // Enable random access reads
+ flags.Pread = true
+ return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ mm, err := getTaskMM(m.t)
+ if err != nil {
+ return 0, nil
+ }
+ defer mm.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
// mapsData implements seqfile.SeqSource for /proc/[pid]/maps.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
index 84baaac66..6af3c3781 100644
--- a/pkg/sentry/fsimpl/devpts/BUILD
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "root_inode_refs.go",
package = "devpts",
prefix = "rootInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "rootInode",
},
@@ -33,6 +33,7 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index d5c5aaa8c..9185877f6 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir
}
fstype.initOnce.Do(func() {
- fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds)
if err != nil {
fstype.initErr = err
return
@@ -93,7 +93,7 @@ type filesystem struct {
// newFilesystem creates a new devpts filesystem with root directory and ptmx
// master inode. It returns the filesystem and root Dentry.
-func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
@@ -108,7 +108,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
root := &rootInode{
replicas: make(map[uint32]*replicaInode),
}
- root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
+ root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
root.EnableLeakCheck()
@@ -120,7 +120,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds
master := &masterInode{
root: root,
}
- master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
+ master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
// Add the master as a child of the root.
links := root.OrderedChildren.Populate(map[string]kernfs.Inode{
@@ -170,7 +170,7 @@ type rootInode struct {
var _ kernfs.Inode = (*rootInode)(nil)
// allocateTerminal creates a new Terminal and installs a pts node for it.
-func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) {
i.mu.Lock()
defer i.mu.Unlock()
if i.nextIdx == math.MaxUint32 {
@@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error)
}
// Linux always uses pty index + 3 as the inode id. See
// fs/devpts/inode.c:devpts_pty_new().
- replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
i.replicas[idx] = replica
return t, nil
@@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
i.mu.Lock()
defer i.mu.Unlock()
+ i.InodeAttrs.TouchAtime(ctx, mnt)
ids := make([]int, 0, len(i.replicas))
for id := range i.replicas {
ids = append(ids, int(id))
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
index e6b0e81cf..ae95fdd08 100644
--- a/pkg/sentry/fsimpl/devpts/line_discipline.go
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -100,10 +100,10 @@ type lineDiscipline struct {
column int
// masterWaiter is used to wait on the master end of the TTY.
- masterWaiter waiter.Queue `state:"zerovalue"`
+ masterWaiter waiter.Queue
// replicaWaiter is used to wait on the replica end of the TTY.
- replicaWaiter waiter.Queue `state:"zerovalue"`
+ replicaWaiter waiter.Queue
}
func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
index fda30fb93..e91fa26a4 100644
--- a/pkg/sentry/fsimpl/devpts/master.go
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil)
// Open implements kernfs.Inode.Open.
func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- t, err := mi.root.allocateTerminal(rp.Credentials())
+ t, err := mi.root.allocateTerminal(ctx, rp.Credentials())
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD
index 01bbee5ad..e49a04c1b 100644
--- a/pkg/sentry/fsimpl/devtmpfs/BUILD
+++ b/pkg/sentry/fsimpl/devtmpfs/BUILD
@@ -4,7 +4,10 @@ licenses(["notice"])
go_library(
name = "devtmpfs",
- srcs = ["devtmpfs.go"],
+ srcs = [
+ "devtmpfs.go",
+ "save_restore.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
new file mode 100644
index 000000000..28832d850
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devtmpfs
+
+// afterLoad is invoked by stateify.
+func (fst *FilesystemType) afterLoad() {
+ if fst.fs != nil {
+ // Ensure that we don't create another filesystem.
+ fst.initOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
index 1c27ad700..5b29f2358 100644
--- a/pkg/sentry/fsimpl/eventfd/eventfd.go
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -43,7 +43,7 @@ type EventFileDescription struct {
// queue is used to notify interested parties when the event object
// becomes readable or writable.
- queue waiter.Queue `state:"zerovalue"`
+ queue waiter.Queue
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 045d7ab08..2158b1bbc 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "inode_refs.go",
package = "fuse",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -49,6 +49,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/kernfs",
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index e39df21c6..e7ef5998e 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// root is the fusefs root directory.
- root := fs.newRootInode(creds, fsopts.rootMode)
+ root := fs.newRootInode(ctx, creds, fsopts.rootMode)
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
@@ -284,9 +284,9 @@ type inode struct {
link string
}
-func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+func (fs *filesystem) newRootInode(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
i := &inode{fs: fs, nodeID: 1}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
+ i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
@@ -295,10 +295,10 @@ func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode)
return &d
}
-func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
+func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode {
i := &inode{fs: fs, nodeID: nodeID}
creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)}
- i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
+ i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode))
atomic.StoreUint64(&i.size, attr.Size)
i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
i.EnableLeakCheck()
@@ -424,7 +424,7 @@ func (i *inode) Keep() bool {
}
// IterDirents implements kernfs.Inode.IterDirents.
-func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
return offset, nil
}
@@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) {
return nil, syserror.EIO
}
- child := i.fs.newInode(out.NodeID, out.Attr)
+ child := i.fs.newInode(ctx, out.NodeID, out.Attr)
return child, nil
}
@@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return linux.FUSEAttr{}, err
@@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
}
// Set the metadata of kernfs.InodeAttrs.
- if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{
+ if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{
Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor),
}); err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 625d1547f..2d396e84c 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u
// May need to update the signature.
i := fd.inode()
- // TODO(gvisor.dev/issue/1193): Invalidate or update atime.
+ i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount())
// Reached EOF.
if sizeRead < size {
@@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
Flags: fd.statusFlags(),
}
+ inode := fd.inode()
var written uint32
// This loop is intended for fragmented write where the bytes to write is
@@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in)
+ req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
if err != nil {
return 0, err
}
@@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
break
}
}
+ inode.InodeAttrs.TouchCMtime(ctx)
return written, nil
}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index ad0afc41b..4c3e9acf8 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
"host_named_pipe.go",
"p9file.go",
"regular_file.go",
+ "save_restore.go",
"socket.go",
"special_file.go",
"symlink.go",
@@ -53,6 +54,7 @@ go_library(
"//pkg/log",
"//pkg/p9",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
@@ -70,6 +72,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
index 18c884b59..e993c8e36 100644
--- a/pkg/sentry/fsimpl/gofer/directory.go
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -16,16 +16,17 @@ package gofer
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
child := &dentry{
refs: 1, // held by d
fs: d.fs,
- ino: d.fs.nextSyntheticIno(),
+ ino: d.fs.nextIno(),
mode: uint32(opts.mode),
uid: uint32(opts.kuid),
gid: uint32(opts.kgid),
@@ -100,6 +101,9 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
hostFD: -1,
nlink: uint32(2),
}
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(child, "gofer.dentry")
+ }
switch opts.mode.FileType() {
case linux.S_IFDIR:
// Nothing else needs to be done.
@@ -235,7 +239,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
}
dirent := vfs.Dirent{
Name: p9d.Name,
- Ino: uint64(inoFromPath(p9d.QID.Path)),
+ Ino: d.fs.inoFromQIDPath(p9d.QID.Path),
NextOff: int64(len(dirents) + 1),
}
// p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 94d96261b..baecb88c4 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -35,7 +35,7 @@ import (
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
- // Snapshot current syncable dentries and special files.
+ // Snapshot current syncable dentries and special file FDs.
fs.syncMu.Lock()
ds := make([]*dentry, 0, len(fs.syncableDentries))
for d := range fs.syncableDentries {
@@ -53,22 +53,28 @@ func (fs *filesystem) Sync(ctx context.Context) error {
// regardless.
var retErr error
- // Sync regular files.
+ // Sync syncable dentries.
for _, d := range ds {
- err := d.syncCachedFile(ctx)
+ err := d.syncCachedFile(ctx, true /* forFilesystemSync */)
d.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
// Sync special files, which may be writable but do not use dentry shared
// handles (so they won't be synced by the above).
for _, sffd := range sffds {
- err := sffd.Sync(ctx)
+ err := sffd.sync(ctx, true /* forFilesystemSync */)
sffd.vfsfd.DecRef(ctx)
- if err != nil && retErr == nil {
- retErr = err
+ if err != nil {
+ ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err)
+ if retErr == nil {
+ retErr = err
+ }
}
}
@@ -229,7 +235,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
return nil, err
}
if child != nil {
- if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ if !file.isNil() && qid.Path == child.qidPath {
// The file at this path hasn't changed. Just update cached metadata.
file.close(ctx)
child.updateFromP9AttrsLocked(attrMask, &attr)
@@ -1512,7 +1518,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath
d.IncRef()
return &endpoint{
dentry: d,
- file: d.file.file,
path: opts.Addr,
}, nil
}
@@ -1591,7 +1596,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
-
-func (fs *filesystem) nextSyntheticIno() inodeNumber {
- return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
-}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index f1dad1b08..80668ebc1 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -26,6 +26,9 @@
// *** "memmap.Mappable locks taken by Translate" below this point
// dentry.handleMu
// dentry.dataMu
+// filesystem.inoMu
+// specialFileFD.mu
+// specialFileFD.bufMu
//
// Locking dentry.dirMu in multiple dentries requires that either ancestor
// dentries are locked before descendant dentries, or that filesystem.renameMu
@@ -36,7 +39,6 @@ import (
"fmt"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
@@ -44,6 +46,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -53,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/usermem"
@@ -81,7 +86,7 @@ type filesystem struct {
iopts InternalFilesystemOptions
// client is the client used by this filesystem. client is immutable.
- client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ client *p9.Client `state:"nosave"`
// clock is a realtime clock used to set timestamps in file operations.
clock ktime.Clock
@@ -89,6 +94,9 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
+ // root is the root dentry. root is immutable.
+ root *dentry
+
// renameMu serves two purposes:
//
// - It synchronizes path resolution with renaming initiated by this
@@ -103,39 +111,35 @@ type filesystem struct {
// cachedDentries contains all dentries with 0 references. (Due to race
// conditions, it may also contain dentries with non-zero references.)
- // cachedDentriesLen is the number of dentries in cachedDentries. These
- // fields are protected by renameMu.
+ // cachedDentriesLen is the number of dentries in cachedDentries. These fields
+ // are protected by renameMu.
cachedDentries dentryList
cachedDentriesLen uint64
- // syncableDentries contains all dentries in this filesystem for which
- // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
- // These fields are protected by syncMu.
+ // syncableDentries contains all non-synthetic dentries. specialFileFDs
+ // contains all open specialFileFDs. These fields are protected by syncMu.
syncMu sync.Mutex `state:"nosave"`
syncableDentries map[*dentry]struct{}
specialFileFDs map[*specialFileFD]struct{}
- // syntheticSeq stores a counter to used to generate unique inodeNumber for
- // synthetic dentries.
- syntheticSeq uint64
-}
+ // inoByQIDPath maps previously-observed QID.Paths to inode numbers
+ // assigned to those paths. inoByQIDPath is not preserved across
+ // checkpoint/restore because QIDs may be reused between different gofer
+ // processes, so QIDs may be repeated for different files across
+ // checkpoint/restore. inoByQIDPath is protected by inoMu.
+ inoMu sync.Mutex `state:"nosave"`
+ inoByQIDPath map[uint64]uint64 `state:"nosave"`
-// inodeNumber represents inode number reported in Dirent.Ino. For regular
-// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
-// have have their inodeNumber generated sequentially, with the MSB reserved to
-// prevent conflicts with regular dentries.
-//
-// +stateify savable
-type inodeNumber uint64
+ // lastIno is the last inode number assigned to a file. lastIno is accessed
+ // using atomic memory operations.
+ lastIno uint64
-// Reserve MSB for synthetic mounts.
-const syntheticInoMask = uint64(1) << 63
+ // savedDentryRW records open read/write handles during save/restore.
+ savedDentryRW map[*dentry]savedDentryRW
-func inoFromPath(path uint64) inodeNumber {
- if path&syntheticInoMask != 0 {
- log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
- }
- return inodeNumber(path &^ syntheticInoMask)
+ // released is nonzero once filesystem.Release has been called. It is accessed
+ // with atomic memory operations.
+ released int32
}
// +stateify savable
@@ -149,8 +153,7 @@ type filesystemOptions struct {
msize uint32
version string
- // maxCachedDentries is the maximum number of dentries with 0 references
- // retained by the client.
+ // maxCachedDentries is the maximum size of filesystem.cachedDentries.
maxCachedDentries uint64
// If forcePageCache is true, host FDs may not be used for application
@@ -247,6 +250,10 @@ const (
//
// +stateify savable
type InternalFilesystemOptions struct {
+ // If UniqueID is non-empty, it is an opaque string used to reassociate the
+ // filesystem with a new server FD during restoration from checkpoint.
+ UniqueID string
+
// If LeakConnection is true, do not close the connection to the server
// when the Filesystem is released. This is necessary for deployments in
// which servers can handle only a single client and report failure if that
@@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mopts := vfs.GenericParseMountOptions(opts.Data)
var fsopts filesystemOptions
- // Check that the transport is "fd".
- trans, ok := mopts["trans"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "trans")
- if trans != "fd" {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
- return nil, nil, syserror.EINVAL
- }
-
- // Check that read and write FDs are provided and identical.
- rfdstr, ok := mopts["rfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "rfdno")
- rfd, err := strconv.Atoi(rfdstr)
- if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
- return nil, nil, syserror.EINVAL
- }
- wfdstr, ok := mopts["wfdno"]
- if !ok {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>")
- return nil, nil, syserror.EINVAL
- }
- delete(mopts, "wfdno")
- wfd, err := strconv.Atoi(wfdstr)
+ fd, err := getFDFromMountOptionsMap(ctx, mopts)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
- return nil, nil, syserror.EINVAL
- }
- if rfd != wfd {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
- return nil, nil, syserror.EINVAL
+ return nil, nil, err
}
- fsopts.fd = rfd
+ fsopts.fd = fd
// Get the attach name.
fsopts.aname = "/"
@@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
// If !ok, iopts being the zero value is correct.
- // Establish a connection with the server.
- conn, err := unet.NewSocket(fsopts.fd)
+ // Construct the filesystem object.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ iopts: iopts,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ devMinor: devMinor,
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
- // Perform version negotiation with the server.
- ctx.UninterruptibleSleepStart(false)
- client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
- ctx.UninterruptibleSleepFinish(false)
- if err != nil {
- conn.Close()
+ // Connect to the server.
+ if err := fs.dial(ctx); err != nil {
return nil, nil, err
}
- // Ownership of conn has been transferred to client.
// Perform attach to obtain the filesystem root.
ctx.UninterruptibleSleepStart(false)
- attached, err := client.Attach(fsopts.aname)
+ attached, err := fs.client.Attach(fsopts.aname)
ctx.UninterruptibleSleepFinish(false)
if err != nil {
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
attachFile := p9file{attached}
qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
if err != nil {
attachFile.close(ctx)
- client.Close()
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
- // Construct the filesystem object.
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- attachFile.close(ctx)
- client.Close()
- return nil, nil, err
- }
- fs := &filesystem{
- mfp: mfp,
- opts: fsopts,
- iopts: iopts,
- client: client,
- clock: ktime.RealtimeClockFromContext(ctx),
- devMinor: devMinor,
- syncableDentries: make(map[*dentry]struct{}),
- specialFileFDs: make(map[*specialFileFD]struct{}),
- }
- fs.vfsfs.Init(vfsObj, &fstype, fs)
-
// Construct the root dentry.
root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
if err != nil {
@@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
// Set the root's reference count to 2. One reference is returned to the
- // caller, and the other is deliberately leaked to prevent the root from
- // being "cached" and subsequently evicted. Its resources will still be
- // cleaned up by fs.Release().
+ // caller, and the other is held by fs to prevent the root from being "cached"
+ // and subsequently evicted.
root.refs = 2
+ fs.root = root
return &fs.vfsfs, &root.vfsd, nil
}
+func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok || trans != "fd" {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr)
+ return -1, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'")
+ return -1, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr)
+ return -1, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return -1, syserror.EINVAL
+ }
+ return rfd, nil
+}
+
+// Preconditions: fs.client == nil.
+func (fs *filesystem) dial(ctx context.Context) error {
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fs.opts.fd)
+ if err != nil {
+ return err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return err
+ }
+ // Ownership of conn has been transferred to client.
+
+ fs.client = client
+ return nil
+}
+
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release(ctx context.Context) {
- mf := fs.mfp.MemoryFile()
+ atomic.StoreInt32(&fs.released, 1)
+ mf := fs.mfp.MemoryFile()
fs.syncMu.Lock()
for d := range fs.syncableDentries {
d.handleMu.Lock()
d.dataMu.Lock()
if h := d.writeHandleLocked(); h.isOpen() {
// Write dirty cached data to the remote file.
- if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil {
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
}
// TODO(jamieliu): Do we need to flushf/fsync d?
@@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) {
// fs.
fs.syncMu.Unlock()
+ // If leak checking is enabled, release all outstanding references in the
+ // filesystem. We deliberately avoid doing this outside of leak checking; we
+ // have released all external resources above rather than relying on dentry
+ // destructors.
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ fs.renameMu.Lock()
+ fs.root.releaseSyntheticRecursiveLocked(ctx)
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // An extra reference was held by the filesystem on the root to prevent it from
+ // being cached/evicted.
+ fs.root.DecRef(ctx)
+ }
+
if !fs.iopts.LeakConnection {
// Close the connection to the server. This implicitly clunks all fids.
fs.client.Close()
@@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
}
+// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements
+// the reference count on every synthetic dentry. Synthetic dentries have one
+// reference for existence that should be dropped during filesystem.Release.
+//
+// Precondition: d.fs.renameMu is locked.
+func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) {
+ if d.isSynthetic() {
+ d.decRefLocked()
+ d.checkCachingLocked(ctx)
+ }
+ if d.isDir() {
+ var children []*dentry
+ d.dirMu.Lock()
+ for _, child := range d.children {
+ children = append(children, child)
+ }
+ d.dirMu.Unlock()
+ for _, child := range children {
+ if child != nil {
+ child.releaseSyntheticRecursiveLocked(ctx)
+ }
+ }
+ }
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -574,12 +635,15 @@ type dentry struct {
// filesystem.renameMu.
name string
+ // qidPath is the p9.QID.Path for this file. qidPath is immutable.
+ qidPath uint64
+
// file is the unopened p9.File that backs this dentry. file is immutable.
//
// If file.isNil(), this dentry represents a synthetic file, i.e. a file
// that does not exist on the remote filesystem. As of this writing, the
// only files that can be synthetic are sockets, pipes, and directories.
- file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ file p9file `state:"nosave"`
// If deleted is non-zero, the file represented by this dentry has been
// deleted. deleted is accessed using atomic memory operations.
@@ -623,12 +687,12 @@ type dentry struct {
// To mutate:
// - Lock metadataMu and use atomic operations to update because we might
// have atomic readers that don't hold the lock.
- metadataMu sync.Mutex `state:"nosave"`
- ino inodeNumber // immutable
- mode uint32 // type is immutable, perms are mutable
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- blockSize uint32 // 0 if unknown
+ metadataMu sync.Mutex `state:"nosave"`
+ ino uint64 // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
// Timestamps, all nsecs from the Unix epoch.
atime int64
mtime int64
@@ -679,9 +743,9 @@ type dentry struct {
// (isNil() == false), it may be mutated with handleMu locked, but cannot
// be closed until the dentry is destroyed.
handleMu sync.RWMutex `state:"nosave"`
- readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
- hostFD int32
+ readFile p9file `state:"nosave"`
+ writeFile p9file `state:"nosave"`
+ hostFD int32 `state:"nosave"`
dataMu sync.RWMutex `state:"nosave"`
@@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d := &dentry{
fs: fs,
+ qidPath: qid.Path,
file: file,
- ino: inoFromPath(qid.Path),
+ ino: fs.inoFromQIDPath(qid.Path),
mode: uint32(attr.Mode),
uid: uint32(fs.opts.dfltuid),
gid: uint32(fs.opts.dfltgid),
@@ -795,6 +860,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
d.nlink = uint32(attr.NLink)
}
d.vfsd.Init(d)
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(d, "gofer.dentry")
+ }
fs.syncMu.Lock()
fs.syncableDentries[d] = struct{}{}
@@ -802,6 +870,21 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma
return d, nil
}
+func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 {
+ fs.inoMu.Lock()
+ defer fs.inoMu.Unlock()
+ if ino, ok := fs.inoByQIDPath[qidPath]; ok {
+ return ino
+ }
+ ino := fs.nextIno()
+ fs.inoByQIDPath[qidPath] = ino
+ return ino
+}
+
+func (fs *filesystem) nextIno() uint64 {
+ return atomic.AddUint64(&fs.lastIno, 1)
+}
+
func (d *dentry) isSynthetic() bool {
return d.file.isNil()
}
@@ -853,7 +936,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
}
}
-// Preconditions: !d.isSynthetic()
+// Preconditions: !d.isSynthetic().
func (d *dentry) updateFromGetattr(ctx context.Context) error {
// Use d.readFile or d.writeFile, which represent 9P fids that have been
// opened, in preference to d.file, which represents a 9P fid that has not.
@@ -916,10 +999,10 @@ func (d *dentry) statTo(stat *linux.Statx) {
// This is consistent with regularFileFD.Seek(), which treats regular files
// as having no holes.
stat.Blocks = (stat.Size + 511) / 512
- stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
- stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
- stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
- stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime))
+ stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime))
stat.DevMajor = linux.UNNAMED_MAJOR
stat.DevMinor = d.fs.devMinor
}
@@ -967,10 +1050,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// Use client clocks for timestamps.
now = d.fs.clock.Now().Nanoseconds()
if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW {
- stat.Atime = statxTimestampFromDentry(now)
+ stat.Atime = linux.NsecToStatxTimestamp(now)
}
if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW {
- stat.Mtime = statxTimestampFromDentry(now)
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
}
}
@@ -1029,11 +1112,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
// !d.cachedMetadataAuthoritative() then we returned after calling
// d.file.setAttr(). For the same reason, now must have been initialized.
if stat.Mask&linux.STATX_ATIME != 0 {
- atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ atomic.StoreInt64(&d.atime, stat.Atime.ToNsec())
atomic.StoreUint32(&d.atimeDirty, 0)
}
if stat.Mask&linux.STATX_MTIME != 0 {
- atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec())
atomic.StoreUint32(&d.mtimeDirty, 0)
}
atomic.StoreInt64(&d.ctime, now)
@@ -1175,6 +1258,11 @@ func (d *dentry) decRefLocked() {
}
}
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
if d.isDir() {
@@ -1223,6 +1311,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
// resolution, which requires renameMu, so if d.refs is zero then it will
// remain zero while we hold renameMu for writing.)
refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
if refs > 0 {
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1231,10 +1323,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
}
return
}
- if refs == -1 {
- // Dentry has already been destroyed.
- return
- }
// Deleted and invalidated dentries with zero references are no longer
// reachable by path resolution and should be dropped immediately.
if d.vfsd.IsDead() {
@@ -1257,6 +1345,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
if d.watches.Size() > 0 {
return
}
+
+ if atomic.LoadInt32(&d.fs.released) != 0 {
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ delete(d.parent.children, d.name)
+ d.parent.dirMu.Unlock()
+ }
+ d.destroyLocked(ctx)
+ }
+
// If d is already cached, just move it to the front of the LRU.
if d.cached {
d.fs.cachedDentries.Remove(d)
@@ -1269,33 +1367,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) {
d.fs.cachedDentriesLen++
d.cached = true
if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
- victim := d.fs.cachedDentries.Back()
- d.fs.cachedDentries.Remove(victim)
- d.fs.cachedDentriesLen--
- victim.cached = false
- // victim.refs may have become non-zero from an earlier path resolution
- // since it was inserted into fs.cachedDentries.
- if atomic.LoadInt64(&victim.refs) == 0 {
- if victim.parent != nil {
- victim.parent.dirMu.Lock()
- if !victim.vfsd.IsDead() {
- // Note that victim can't be a mount point (in any mount
- // namespace), since VFS holds references on mount points.
- d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
- delete(victim.parent.children, victim.name)
- // We're only deleting the dentry, not the file it
- // represents, so we don't need to update
- // victimParent.dirents etc.
- }
- victim.parent.dirMu.Unlock()
- }
- victim.destroyLocked(ctx)
- }
+ d.fs.evictCachedDentryLocked(ctx)
// Whether or not victim was destroyed, we brought fs.cachedDentriesLen
// back down to fs.opts.maxCachedDentries, so we don't loop.
}
}
+// Precondition: fs.renameMu must be locked for writing; it may be temporarily
+// unlocked.
+func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) {
+ for fs.cachedDentriesLen != 0 {
+ fs.evictCachedDentryLocked(ctx)
+ }
+}
+
+// Preconditions:
+// * fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// * fs.cachedDentriesLen != 0.
+func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) {
+ victim := fs.cachedDentries.Back()
+ fs.cachedDentries.Remove(victim)
+ fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+}
+
// destroyLocked destroys the dentry.
//
// Preconditions:
@@ -1380,6 +1493,10 @@ func (d *dentry) destroyLocked(ctx context.Context) {
panic("gofer.dentry.DecRef() called without holding a reference")
}
}
+
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Unregister(d, "gofer.dentry")
+ }
}
func (d *dentry) isDeleted() bool {
@@ -1623,6 +1740,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
return nil
}
+func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ if h.isOpen() {
+ // Write back dirty pages to the remote file.
+ d.dataMu.Lock()
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ if err := d.syncRemoteFileLocked(ctx); err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (d is a regular file and was opened for writing).
+ if d.isRegularFile() && h.isOpen() {
+ return err
+ }
+ ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err)
+ }
+ return nil
+}
+
// incLinks increments link count.
func (d *dentry) incLinks() {
if atomic.LoadUint32(&d.nlink) == 0 {
@@ -1650,7 +1794,7 @@ type fileDescription struct {
vfs.FileDescriptionDefaultImpl
vfs.LockFD
- lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ lockLogging sync.Once `state:"nosave"`
}
func (fd *fileDescription) filesystem() *filesystem {
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
index bfe75dfe4..76f08e252 100644
--- a/pkg/sentry/fsimpl/gofer/gofer_test.go
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -26,12 +26,13 @@ import (
func TestDestroyIdempotent(t *testing.T) {
ctx := contexttest.Context(t)
fs := filesystem{
- mfp: pgalloc.MemoryFileProviderFromContext(ctx),
- syncableDentries: make(map[*dentry]struct{}),
+ mfp: pgalloc.MemoryFileProviderFromContext(ctx),
opts: filesystemOptions{
// Test relies on no dentry being held in the cache.
maxCachedDentries: 0,
},
+ syncableDentries: make(map[*dentry]struct{}),
+ inoByQIDPath: make(map[uint64]uint64),
}
attr := &p9.Attr{
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
index 7294de7d6..c7bf10007 100644
--- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
if ok {
return nil
}
- if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
- return err
+ if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil {
+ // Another application thread may have opened this pipe for
+ // writing, succeeded because we previously opened the pipe for
+ // reading, and subsequently interrupted us for checkpointing (e.g.
+ // this occurs in mknod tests under cooperative save/restore). In
+ // this case, our open has to succeed for the checkpoint to include
+ // a readable FD for the pipe, which is in turn necessary to
+ // restore the other thread's writable FD for the same pipe
+ // (otherwise it will get ENXIO). So we have to check
+ // nonblockingPipeHasWriter() once last time.
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ return sleepErr
}
}
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index f8b19bae7..dc8a890cb 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -18,7 +18,6 @@ import (
"fmt"
"io"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *regularFileFD) Sync(ctx context.Context) error {
- return fd.dentry().syncCachedFile(ctx)
-}
-
-func (d *dentry) syncCachedFile(ctx context.Context) error {
- d.handleMu.RLock()
- defer d.handleMu.RUnlock()
-
- if h := d.writeHandleLocked(); h.isOpen() {
- d.dataMu.Lock()
- // Write dirty cached data to the remote file.
- err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
- d.dataMu.Unlock()
- if err != nil {
- return err
- }
- }
- return d.syncRemoteFileLocked(ctx)
+ return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
@@ -913,7 +897,7 @@ type dentryPlatformFile struct {
hostFileMapper fsutil.HostFileMapper
// hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
- hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ hostFileMapperInitOnce sync.Once `state:"nosave"`
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go
new file mode 100644
index 000000000..2ea224c43
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/save_restore.go
@@ -0,0 +1,329 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type saveRestoreContextID int
+
+const (
+ // CtxRestoreServerFDMap is a Context.Value key for a map[string]int
+ // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID)
+ // to host FDs.
+ CtxRestoreServerFDMap saveRestoreContextID = iota
+)
+
+// +stateify savable
+type savedDentryRW struct {
+ read bool
+ write bool
+}
+
+// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave.
+func (fs *filesystem) PrepareSave(ctx context.Context) error {
+ if len(fs.iopts.UniqueID) == 0 {
+ return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved")
+ }
+
+ // Purge cached dentries, which may not be reopenable after restore due to
+ // permission changes.
+ fs.renameMu.Lock()
+ fs.evictAllCachedDentriesLocked(ctx)
+ fs.renameMu.Unlock()
+
+ // Buffer pipe data so that it's available for reading after restore. (This
+ // is a legacy VFS1 feature.)
+ fs.syncMu.Lock()
+ for sffd := range fs.specialFileFDs {
+ if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() {
+ if err := sffd.savePipeData(ctx); err != nil {
+ fs.syncMu.Unlock()
+ return err
+ }
+ }
+ }
+ fs.syncMu.Unlock()
+
+ // Flush local state to the remote filesystem.
+ if err := fs.Sync(ctx); err != nil {
+ return err
+ }
+
+ fs.savedDentryRW = make(map[*dentry]savedDentryRW)
+ return fs.root.prepareSaveRecursive(ctx)
+}
+
+// Preconditions:
+// * fd represents a pipe.
+// * fd is readable.
+func (fd *specialFileFD) savePipeData(ctx context.Context) error {
+ fd.bufMu.Lock()
+ defer fd.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0))
+ if n != 0 {
+ fd.buf = append(fd.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syserror.EAGAIN {
+ break
+ }
+ return err
+ }
+ }
+ if len(fd.buf) != 0 {
+ atomic.StoreUint32(&fd.haveBuf, 1)
+ }
+ return nil
+}
+
+func (d *dentry) prepareSaveRecursive(ctx context.Context) error {
+ if d.isRegularFile() && !d.cachedMetadataAuthoritative() {
+ // Get updated metadata for d in case we need to perform metadata
+ // validation during restore.
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ if !d.readFile.isNil() || !d.writeFile.isNil() {
+ d.fs.savedDentryRW[d] = savedDentryRW{
+ read: !d.readFile.isNil(),
+ write: !d.writeFile.isNil(),
+ }
+ }
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ for _, child := range d.children {
+ if child != nil {
+ if err := child.prepareSaveRecursive(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// beforeSave is invoked by stateify.
+func (d *dentry) beforeSave() {
+ if d.vfsd.IsDead() {
+ panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d)))
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentry) afterLoad() {
+ d.hostFD = -1
+ if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d, "gofer.dentry")
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (d *dentryPlatformFile) afterLoad() {
+ if d.hostFileMapper.IsInited() {
+ // Ensure that we don't call d.hostFileMapper.Init() again.
+ d.hostFileMapperInitOnce.Do(func() {})
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (fd *specialFileFD) afterLoad() {
+ fd.handle.fd = -1
+}
+
+// CompleteRestore implements
+// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore.
+func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error {
+ fdmapv := ctx.Value(CtxRestoreServerFDMap)
+ if fdmapv == nil {
+ return fmt.Errorf("no server FD map available")
+ }
+ fdmap := fdmapv.(map[string]int)
+ fd, ok := fdmap[fs.iopts.UniqueID]
+ if !ok {
+ return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID)
+ }
+ fs.opts.fd = fd
+ if err := fs.dial(ctx); err != nil {
+ return err
+ }
+ fs.inoByQIDPath = make(map[uint64]uint64)
+
+ // Restore the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := fs.client.Attach(fs.opts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return err
+ }
+ if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil {
+ return err
+ }
+
+ // Restore remaining dentries.
+ if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil {
+ return err
+ }
+
+ // Re-open handles for specialFileFDs. Unlike the initial open
+ // (dentry.openSpecialFile()), pipes are always opened without blocking;
+ // non-readable pipe FDs are opened last to ensure that they don't get
+ // ENXIO if another specialFileFD represents the read end of the same pipe.
+ // This is consistent with VFS1.
+ haveWriteOnlyPipes := false
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ haveWriteOnlyPipes = true
+ continue
+ }
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ if haveWriteOnlyPipes {
+ for fd := range fs.specialFileFDs {
+ if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() {
+ if err := fd.completeRestore(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ }
+
+ // Discard state only required during restore.
+ fs.savedDentryRW = nil
+
+ return nil
+}
+
+func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error {
+ d.file = file
+
+ // Gofers do not preserve QID across checkpoint/restore, so:
+ //
+ // - We must assume that the remote filesystem did not change in a way that
+ // would invalidate dentries, since we can't revalidate dentries by
+ // checking QIDs.
+ //
+ // - We need to associate the new QID.Path with the existing d.ino.
+ d.qidPath = qid.Path
+ d.fs.inoMu.Lock()
+ d.fs.inoByQIDPath[qid.Path] = d.ino
+ d.fs.inoMu.Unlock()
+
+ // Check metadata stability before updating metadata.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.isRegularFile() {
+ if opts.ValidateFileSizes {
+ if !attrMask.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d))
+ }
+ if d.size != attr.Size {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size)
+ }
+ }
+ if opts.ValidateFileModificationTimestamps {
+ if !attrMask.MTime {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d))
+ }
+ if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want {
+ return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want))
+ }
+ }
+ }
+ if !d.cachedMetadataAuthoritative() {
+ d.updateFromP9AttrsLocked(attrMask, attr)
+ }
+
+ if rw, ok := d.fs.savedDentryRW[d]; ok {
+ if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// Preconditions: d is not synthetic.
+func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ for _, child := range d.children {
+ if child == nil {
+ continue
+ }
+ if _, ok := d.fs.syncableDentries[child]; !ok {
+ // child is synthetic.
+ continue
+ }
+ if err := child.restoreRecursive(ctx, opts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Preconditions: d is not synthetic (but note that since this function
+// restores d.file, d.file.isNil() is always true at this point, so this can
+// only be detected by checking filesystem.syncableDentries). d.parent has been
+// restored.
+func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error {
+ qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name)
+ if err != nil {
+ return err
+ }
+ if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil {
+ return err
+ }
+ return d.restoreDescendantsRecursive(ctx, opts)
+}
+
+func (fd *specialFileFD) completeRestore(ctx context.Context) error {
+ d := fd.dentry()
+ h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */)
+ if err != nil {
+ return err
+ }
+ fd.handle = h
+
+ ftype := d.fileType()
+ fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0
+ if fd.haveQueue {
+ if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
index 326b940a7..a21199eac 100644
--- a/pkg/sentry/fsimpl/gofer/socket.go
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -42,9 +42,6 @@ type endpoint struct {
// dentry is the filesystem dentry which produced this endpoint.
dentry *dentry
- // file is the p9 file that contains a single unopened fid.
- file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
-
// path is the sentry path where this endpoint is bound.
path string
}
@@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
}
func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
- hostFile, err := e.file.Connect(flags)
+ hostFile, err := e.dentry.file.connect(ctx, flags)
if err != nil {
return nil, syserr.ErrConnectionRefused
}
@@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla
c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
if serr != nil {
- log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr)
return nil, serr
}
return c, nil
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
index 71581736c..625400c0b 100644
--- a/pkg/sentry/fsimpl/gofer/special_file.go
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -15,7 +15,6 @@
package gofer
import (
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -40,7 +40,7 @@ type specialFileFD struct {
fileDescription
// handle is used for file I/O. handle is immutable.
- handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ handle handle `state:"nosave"`
// isRegularFile is true if this FD represents a regular file which is only
// possible when filesystemOptions.regularFilesUseSpecialFileFD is in
@@ -54,12 +54,20 @@ type specialFileFD struct {
// haveQueue is true if this file description represents a file for which
// queue may send I/O readiness events. haveQueue is immutable.
- haveQueue bool
+ haveQueue bool `state:"nosave"`
queue waiter.Queue
// If seekable is true, off is the file offset. off is protected by mu.
mu sync.Mutex `state:"nosave"`
off int64
+
+ // If haveBuf is non-zero, this FD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to specialFileFD.savePipeData().
+ // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic
+ // memory operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
}
func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
@@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks,
}
return nil, err
}
+ d.fs.syncMu.Lock()
+ d.fs.specialFileFDs[fd] = struct{}{}
+ d.fs.syncMu.Unlock()
return fd, nil
}
@@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs
return 0, syserror.EOPNOTSUPP
}
- // Going through dst.CopyOutFrom() holds MM locks around file operations of
- // unknown duration. For regularFileFD, doing so is necessary to support
- // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
- // hold here since specialFileFD doesn't client-cache data. Just buffer the
- // read instead.
if d := fd.dentry(); d.cachedMetadataAuthoritative() {
d.touchAtime(fd.vfsfd.Mount())
}
+
+ bufN := int64(0)
+ if atomic.LoadUint32(&fd.haveBuf) != 0 {
+ var err error
+ fd.bufMu.Lock()
+ if len(fd.buf) != 0 {
+ var n int
+ n, err = dst.CopyOut(ctx, fd.buf)
+ dst = dst.DropFirst(n)
+ fd.buf = fd.buf[n:]
+ if len(fd.buf) == 0 {
+ atomic.StoreUint32(&fd.haveBuf, 0)
+ fd.buf = nil
+ }
+ bufN = int64(n)
+ if offset >= 0 {
+ offset += bufN
+ }
+ }
+ fd.bufMu.Unlock()
+ if err != nil {
+ return bufN, err
+ }
+ }
+
+ // Going through dst.CopyOutFrom() would hold MM locks around file
+ // operations of unknown duration. For regularFileFD, doing so is necessary
+ // to support mmap due to lock ordering; MM locks precede dentry.dataMu.
+ // That doesn't hold here since specialFileFD doesn't client-cache data.
+ // Just buffer the read instead.
buf := make([]byte, dst.NumBytes())
n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
if err == syserror.EAGAIN {
err = syserror.ErrWouldBlock
}
if n == 0 {
- return 0, err
+ return bufN, err
}
if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
- return int64(cp), cperr
+ return bufN + int64(cp), cperr
}
- return int64(n), err
+ return bufN + int64(n), err
}
// Read implements vfs.FileDescriptionImpl.Read.
@@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off
}
d := fd.dentry()
- // If the regular file fd was opened with O_APPEND, make sure the file size
- // is updated. There is a possible race here if size is modified externally
- // after metadata cache is updated.
- if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
- if err := d.updateFromGetattr(ctx); err != nil {
- return 0, offset, err
+ if fd.isRegularFile {
+ // If the regular file fd was opened with O_APPEND, make sure the file
+ // size is updated. There is a possible race here if size is modified
+ // externally after metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
}
- }
- if fd.isRegularFile {
// We need to hold the metadataMu *while* writing to a regular file.
d.metadataMu.Lock()
defer d.metadataMu.Unlock()
@@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// Sync implements vfs.FileDescriptionImpl.Sync.
func (fd *specialFileFD) Sync(ctx context.Context) error {
- // If we have a host FD, fsyncing it is likely to be faster than an fsync
- // RPC.
- if fd.handle.fd >= 0 {
- ctx.UninterruptibleSleepStart(false)
- err := syscall.Fsync(int(fd.handle.fd))
- ctx.UninterruptibleSleepFinish(false)
- return err
+ return fd.sync(ctx, false /* forFilesystemSync */)
+}
+
+func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error {
+ err := func() error {
+ // If we have a host FD, fsyncing it is likely to be faster than an fsync
+ // RPC.
+ if fd.handle.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(fd.handle.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ return fd.handle.file.fsync(ctx)
+ }()
+ if err != nil {
+ if !forFilesystemSync {
+ return err
+ }
+ // Only return err if we can reasonably have expected sync to succeed
+ // (fd represents a regular file that was opened for writing).
+ if fd.isRegularFile && fd.vfsfd.IsWritable() {
+ return err
+ }
+ ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err)
}
- return fd.handle.file.fsync(ctx)
+ return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
index 7e825caae..9cbe805b9 100644
--- a/pkg/sentry/fsimpl/gofer/time.go
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -17,7 +17,6 @@ package gofer
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 {
return int64(s*1e9 + ns)
}
-func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
- return ts.Sec*1e9 + int64(ts.Nsec)
-}
-
-func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
- return linux.StatxTimestamp{
- Sec: ns / 1e9,
- Nsec: uint32(ns % 1e9),
- }
-}
-
// Preconditions: d.cachedMetadataAuthoritative() == true.
func (d *dentry) touchAtime(mnt *vfs.Mount) {
if mnt.Flags.NoATime || mnt.ReadOnly() {
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index 56bcf9bdb..dc0f86061 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "inode_refs.go",
package = "host",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "connected_endpoint_refs.go",
package = "host",
prefix = "ConnectedEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ConnectedEndpoint",
},
@@ -34,6 +34,7 @@ go_library(
"inode_refs.go",
"ioctl_unsafe.go",
"mmap.go",
+ "save_restore.go",
"socket.go",
"socket_iovec.go",
"socket_unsafe.go",
@@ -51,6 +52,7 @@ go_library(
"//pkg/log",
"//pkg/marshal/primitive",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/fsutil",
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
index 0135e4428..13ef48cb5 100644
--- a/pkg/sentry/fsimpl/host/control.go
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription {
}
// Create the file backed by hostFD.
- file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */)
+ file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{})
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 698e913fe..eeed0f97d 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -19,6 +19,7 @@ package host
import (
"fmt"
"math"
+ "sync/atomic"
"syscall"
"golang.org/x/sys/unix"
@@ -40,34 +41,106 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) {
- // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
- // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
- // devices.
+// inode implements kernfs.Inode.
+//
+// +stateify savable
+type inode struct {
+ kernfs.InodeNoStatFS
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
+
+ locks vfs.FileLocks
+
+ // When the reference count reaches zero, the host fd is closed.
+ inodeRefs
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // ftype is the file's type (a linux.S_IFMT mask).
+ //
+ // This field is initialized at creation time and is immutable.
+ ftype uint16
+
+ // mayBlock is true if hostFD is non-blocking, and operations on it may
+ // return EAGAIN or EWOULDBLOCK instead of blocking.
+ //
+ // This field is initialized at creation time and is immutable.
+ mayBlock bool
+
+ // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file
+ // offsets are meaningful iff seekable is true.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // savable is true if hostFD may be saved/restored by its numeric value.
+ //
+ // This field is initialized at creation time and is immutable.
+ savable bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // If this file is mmappable, mappings tracks mappings of hostFD into
+ // memmap.MappingSpaces.
+ mappings memmap.MappingSet
+
+ // pf implements platform.File for mappings of hostFD.
+ pf inodePlatformFile
+
+ // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to inode.beforeSave(). haveBuf
+ // and buf are protected by bufMu. haveBuf is accessed using atomic memory
+ // operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
+}
+
+func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) {
+ // Determine if hostFD is seekable.
_, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
seekable := err != syserror.ESPIPE
+ // We expect regular files to be seekable, as this is required for them to
+ // be memory-mappable.
+ if !seekable && fileType == syscall.S_IFREG {
+ ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD)
+ return nil, syserror.ESPIPE
+ }
i := &inode{
- hostFD: hostFD,
- ino: fs.NextIno(),
- isTTY: isTTY,
- wouldBlock: wouldBlock(uint32(fileType)),
- seekable: seekable,
- // NOTE(b/38213152): Technically, some obscure char devices can be memory
- // mapped, but we only allow regular files.
- canMap: fileType == linux.S_IFREG,
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ ftype: uint16(fileType),
+ mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR,
+ seekable: seekable,
+ isTTY: isTTY,
+ savable: savable,
}
i.pf.inode = i
i.EnableLeakCheck()
- // Non-seekable files can't be memory mapped, assert this.
- if !i.seekable && i.canMap {
- panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
- }
-
- // If the hostFD would block, we must set it to non-blocking and handle
- // blocking behavior in the sentry.
- if i.wouldBlock {
+ // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and
+ // handle blocking behavior in the sentry.
+ if i.mayBlock {
if err := syscall.SetNonblock(i.hostFD, true); err != nil {
return nil, err
}
@@ -80,6 +153,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (
// NewFDOptions contains options to NewFD.
type NewFDOptions struct {
+ // If Savable is true, the host file descriptor may be saved/restored by
+ // numeric value; the sandbox API requires a corresponding host FD with the
+ // same numeric value to be provieded at time of restore.
+ Savable bool
+
// If IsTTY is true, the file descriptor is a TTY.
IsTTY bool
@@ -114,7 +192,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
}
d := &kernfs.Dentry{}
- i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
+ i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
if err != nil {
return nil, err
}
@@ -132,7 +210,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
return NewFD(ctx, mnt, hostFD, &NewFDOptions{
- IsTTY: isTTY,
+ Savable: true,
+ IsTTY: isTTY,
})
}
@@ -191,68 +270,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
-// inode implements kernfs.Inode.
-//
-// +stateify savable
-type inode struct {
- kernfs.InodeNoStatFS
- kernfs.InodeNotDirectory
- kernfs.InodeNotSymlink
- kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid.
-
- locks vfs.FileLocks
-
- // When the reference count reaches zero, the host fd is closed.
- inodeRefs
-
- // hostFD contains the host fd that this file was originally created from,
- // which must be available at time of restore.
- //
- // This field is initialized at creation time and is immutable.
- hostFD int
-
- // ino is an inode number unique within this filesystem.
- //
- // This field is initialized at creation time and is immutable.
- ino uint64
-
- // isTTY is true if this file represents a TTY.
- //
- // This field is initialized at creation time and is immutable.
- isTTY bool
-
- // seekable is false if the host fd points to a file representing a stream,
- // e.g. a socket or a pipe. Such files are not seekable and can return
- // EWOULDBLOCK for I/O operations.
- //
- // This field is initialized at creation time and is immutable.
- seekable bool
-
- // wouldBlock is true if the host FD would return EWOULDBLOCK for
- // operations that would block.
- //
- // This field is initialized at creation time and is immutable.
- wouldBlock bool
-
- // Event queue for blocking operations.
- queue waiter.Queue
-
- // canMap specifies whether we allow the file to be memory mapped.
- //
- // This field is initialized at creation time and is immutable.
- canMap bool
-
- // mapsMu protects mappings.
- mapsMu sync.Mutex `state:"nosave"`
-
- // If canMap is true, mappings tracks mappings of hostFD into
- // memmap.MappingSpaces.
- mappings memmap.MappingSet
-
- // pf implements platform.File for mappings of hostFD.
- pf inodePlatformFile
-}
-
// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s syscall.Stat_t
@@ -448,7 +465,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
// DecRef implements kernfs.Inode.DecRef.
func (i *inode) DecRef(ctx context.Context) {
i.inodeRefs.DecRef(func() {
- if i.wouldBlock {
+ if i.mayBlock {
fdnotifier.RemoveFD(int32(i.hostFD))
}
if err := unix.Close(i.hostFD); err != nil {
@@ -567,6 +584,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin
// PRead implements vfs.FileDescriptionImpl.PRead.
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
return 0, syserror.ESPIPE
@@ -577,19 +601,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off
// Read implements vfs.FileDescriptionImpl.Read.
func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
i := f.inode
if !i.seekable {
+ bufN, err := i.readFromBuf(ctx, &dst)
+ if err != nil {
+ return bufN, err
+ }
n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ total := bufN + n
if isBlockError(err) {
// If we got any data at all, return it as a "completed" partial read
// rather than retrying until complete.
- if n != 0 {
+ if total != 0 {
err = nil
} else {
err = syserror.ErrWouldBlock
}
}
- return n, err
+ return total, err
}
f.offsetMu.Lock()
@@ -599,13 +635,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts
return n, err
}
-func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
- // Check that flags are supported.
- //
- // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
- if flags&^linux.RWF_HIPRI != 0 {
- return 0, syserror.EOPNOTSUPP
+func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) {
+ if atomic.LoadUint32(&i.haveBuf) == 0 {
+ return 0, nil
}
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ if len(i.buf) == 0 {
+ return 0, nil
+ }
+ n, err := dst.CopyOut(ctx, i.buf)
+ *dst = dst.DropFirst(n)
+ i.buf = i.buf[n:]
+ if len(i.buf) == 0 {
+ atomic.StoreUint32(&i.haveBuf, 0)
+ i.buf = nil
+ }
+ return int64(n), err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
n, err := dst.CopyOutFrom(ctx, reader)
hostfd.PutReadWriterAt(reader)
@@ -735,14 +784,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i
}
// Sync implements vfs.FileDescriptionImpl.Sync.
-func (f *fileDescription) Sync(context.Context) error {
+func (f *fileDescription) Sync(ctx context.Context) error {
// TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
return unix.Fsync(f.inode.hostFD)
}
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
- if !f.inode.canMap {
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ if f.inode.ftype != syscall.S_IFREG {
return syserror.ENODEV
}
i := f.inode
@@ -753,13 +804,17 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts
// EventRegister implements waiter.Waitable.EventRegister.
func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
f.inode.queue.EventRegister(e, mask)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// EventUnregister implements waiter.Waitable.EventUnregister.
func (f *fileDescription) EventUnregister(e *waiter.Entry) {
f.inode.queue.EventUnregister(e)
- fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ if f.inode.mayBlock {
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+ }
}
// Readiness uses the poll() syscall to check the status of the underlying FD.
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
index b51a17bed..3d7eb2f96 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -43,7 +43,7 @@ type inodePlatformFile struct {
fileMapper fsutil.HostFileMapper
// fileMapperInitOnce is used to lazily initialize fileMapper.
- fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported.
+ fileMapperInitOnce sync.Once `state:"nosave"`
}
// IncRef implements memmap.File.IncRef.
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
new file mode 100644
index 000000000..7e32a8863
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// beforeSave is invoked by stateify.
+func (i *inode) beforeSave() {
+ if !i.savable {
+ panic("host.inode is not savable")
+ }
+ if i.ftype == syscall.S_IFIFO {
+ // If this pipe FD is readable, drain it so that bytes in the pipe can
+ // be read after restore. (This is a legacy VFS1 feature.) We don't
+ // know if the pipe FD is readable, so just try reading and tolerate
+ // EBADF from the read.
+ i.bufMu.Lock()
+ defer i.bufMu.Unlock()
+ var buf [usermem.PageSize]byte
+ for {
+ n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */)
+ if n != 0 {
+ i.buf = append(i.buf, buf[:n]...)
+ }
+ if err != nil {
+ if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF {
+ break
+ }
+ panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err))
+ }
+ }
+ if len(i.buf) != 0 {
+ atomic.StoreUint32(&i.haveBuf, 1)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inode) afterLoad() {
+ if i.mayBlock {
+ if err := syscall.SetNonblock(i.hostFD, true); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err))
+ }
+ if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil {
+ panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err))
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (i *inodePlatformFile) afterLoad() {
+ if i.fileMapper.IsInited() {
+ // Ensure that we don't call i.fileMapper.Init() again.
+ i.fileMapperInitOnce.Do(func() {})
+ }
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
index 412bdb2eb..b2f43a119 100644
--- a/pkg/sentry/fsimpl/host/util.go
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
}
-// wouldBlock returns true for file types that can return EWOULDBLOCK
-// for blocking operations, e.g. pipes, character devices, and sockets.
-func wouldBlock(fileType uint32) bool {
- return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
-}
-
// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
// If so, they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 858cc24ce..aaad67ab8 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "kernfs",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Dentry",
+ "Linker": "*Dentry",
+ },
+)
+
+go_template_instance(
name = "fstree",
out = "fstree.go",
package = "kernfs",
@@ -27,22 +39,11 @@ go_template_instance(
)
go_template_instance(
- name = "dentry_refs",
- out = "dentry_refs.go",
- package = "kernfs",
- prefix = "Dentry",
- template = "//pkg/refs_vfs2:refs_template",
- types = {
- "T": "Dentry",
- },
-)
-
-go_template_instance(
name = "static_directory_refs",
out = "static_directory_refs.go",
package = "kernfs",
prefix = "StaticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "StaticDirectory",
},
@@ -53,7 +54,7 @@ go_template_instance(
out = "dir_refs.go",
package = "kernfs_test",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -64,7 +65,7 @@ go_template_instance(
out = "readonly_dir_refs.go",
package = "kernfs_test",
prefix = "readonlyDir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "readonlyDir",
},
@@ -75,7 +76,7 @@ go_template_instance(
out = "synthetic_directory_refs.go",
package = "kernfs",
prefix = "syntheticDirectory",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "syntheticDirectory",
},
@@ -84,7 +85,7 @@ go_template_instance(
go_library(
name = "kernfs",
srcs = [
- "dentry_refs.go",
+ "dentry_list.go",
"dynamic_bytes_file.go",
"fd_impl_util.go",
"filesystem.go",
@@ -104,8 +105,11 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
@@ -129,6 +133,7 @@ go_test(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index b929118b1..485504995 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -47,11 +47,11 @@ type DynamicBytesFile struct {
var _ Inode = (*DynamicBytesFile)(nil)
// Init initializes a dynamic bytes file.
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
f.data = data
}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index abf1905d6..f8dae22f8 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
return fd.vfsfd.VirtualDentry().Mount().Filesystem()
}
+func (fd *GenericDirectoryFD) dentry() *Dentry {
+ return fd.vfsfd.Dentry().Impl().(*Dentry)
+}
+
func (fd *GenericDirectoryFD) inode() Inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return fd.dentry().inode
}
// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
@@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
// Handle "..".
if fd.off == 1 {
- vfsd := fd.vfsfd.VirtualDentry().Dentry()
- parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
+ parentInode := genericParentOrSelf(fd.dentry()).inode
stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
if err != nil {
return err
@@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
var err error
relOffset := fd.off - int64(len(fd.children.set)) - 2
- fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset)
return err
}
@@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
creds := auth.CredentialsFromContext(ctx)
- inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
- return inode.SetStat(ctx, fd.filesystem(), creds, opts)
+ return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts)
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 6426a55f6..399895f3e 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -373,7 +373,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
return err
}
- childI = newSyntheticDirectory(rp.Credentials(), opts.Mode)
+ childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode)
}
var child Dentry
child.Init(fs, childI)
@@ -517,9 +517,6 @@ afterTrailingSymlink:
}
var child Dentry
child.Init(fs, childI)
- // FIXME(gvisor.dev/issue/1193): Race between checking existence with
- // fs.stepExistingLocked and parent.insertChild. If possible, we should hold
- // dirMu from one to the other.
parent.insertChild(pc, &child)
// Open may block so we need to unlock fs.mu. IncRef child to prevent
// its destruction while fs.mu is unlocked.
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 122b10591..d9d76758a 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -21,9 +21,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// InodeNoopRefCount partially implements the Inode interface, specifically the
@@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error)
}
// IterDirents implements Inode.IterDirents.
-func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
panic("IterDirents called on non-directory inode")
}
@@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry,
//
// +stateify savable
type InodeAttrs struct {
- devMajor uint32
- devMinor uint32
- ino uint64
- mode uint32
- uid uint32
- gid uint32
- nlink uint32
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+ blockSize uint32
+
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
}
// Init initializes this InodeAttrs.
-func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
+func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
if mode.FileType() == 0 {
panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
}
@@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in
atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
atomic.StoreUint32(&a.nlink, nlink)
+ atomic.StoreUint32(&a.blockSize, usermem.PageSize)
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.atime, now)
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
}
// DevMajor returns the device major number.
@@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode {
return linux.FileMode(atomic.LoadUint32(&a.mode))
}
+// TouchAtime updates a.atime to the current time.
+func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) {
+ if mnt.Flags.NoATime || mnt.ReadOnly() {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds())
+ mnt.EndWrite()
+}
+
+// TouchCMtime updates a.{c/m}time to the current time. The caller should
+// synchronize calls to this so that ctime and mtime are updated to the same
+// value.
+func (a *InodeAttrs) TouchCMtime(ctx context.Context) {
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ atomic.StoreInt64(&a.mtime, now)
+ atomic.StoreInt64(&a.ctime, now)
+}
+
// Stat partially implements Inode.Stat. Note that this function doesn't provide
// all the stat fields, and the embedder should consider extending the result
// with filesystem-specific fields.
func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME
stat.DevMajor = a.devMajor
stat.DevMinor = a.devMinor
stat.Ino = atomic.LoadUint64(&a.ino)
@@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li
stat.UID = atomic.LoadUint32(&a.uid)
stat.GID = atomic.LoadUint32(&a.gid)
stat.Nlink = atomic.LoadUint32(&a.nlink)
-
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
-
+ stat.Blksize = atomic.LoadUint32(&a.blockSize)
+ stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime))
+ stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime))
+ stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime))
return stat, nil
}
// SetStat implements Inode.SetStat.
func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
- return a.SetInodeStat(ctx, fs, creds, opts)
-}
-
-// SetInodeStat sets the corresponding attributes from opts to InodeAttrs.
-// This function can be used by other kernfs-based filesystem implementation to
-// sets the unexported attributes into InodeAttrs.
-func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
if opts.Stat.Mask == 0 {
return nil
}
@@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
// inode numbers are immutable after node creation. Setting the size is often
// allowed by kernfs files but does not do anything. If some other behavior is
// needed, the embedder should consider extending SetStat.
- //
- // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 {
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
return syserror.EPERM
}
if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() {
@@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds
atomic.StoreUint32(&a.gid, stat.GID)
}
+ now := ktime.NowFromContext(ctx).Nanoseconds()
+ if stat.Mask&linux.STATX_ATIME != 0 {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ stat.Atime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.atime, stat.Atime.ToNsec())
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ stat.Mtime = linux.NsecToStatxTimestamp(now)
+ }
+ atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec())
+ }
+
return nil
}
@@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error
}
// IterDirents implements Inode.IterDirents.
-func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
// All entries from OrderedChildren have already been handled in
// GenericDirectoryFD.IterDirents.
return offset, nil
@@ -619,9 +659,9 @@ type StaticDirectory struct {
var _ Inode = (*StaticDirectory)(nil)
// NewStaticDir creates a new static directory and returns its dentry.
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
+func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode {
inode := &StaticDirectory{}
- inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(OrderedChildrenOptions{})
@@ -632,12 +672,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64
}
// Init initializes StaticDirectory.
-func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
+func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
s.fdOpts = fdOpts
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
}
// Open implements Inode.Open.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 606081e68..5c5e09ac5 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -107,6 +107,17 @@ type Filesystem struct {
// nextInoMinusOne is used to to allocate inode numbers on this
// filesystem. Must be accessed by atomic operations.
nextInoMinusOne uint64
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by mu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // MaxCachedDentries is the maximum size of cachedDentries. If not set,
+ // defaults to 0 and kernfs does not cache any dentries. This is immutable.
+ MaxCachedDentries uint64
}
// deferDecRef defers dropping a dentry ref until the next call to
@@ -165,7 +176,12 @@ const (
// +stateify savable
type Dentry struct {
vfsd vfs.Dentry
- DentryRefs
+
+ // refs is the reference count. When refs reaches 0, the dentry may be
+ // added to the cache or destroyed. If refs == -1, the dentry has already
+ // been destroyed. refs are allowed to go to 0 and increase again. refs is
+ // accessed using atomic memory operations.
+ refs int64
// fs is the owning filesystem. fs is immutable.
fs *Filesystem
@@ -177,6 +193,12 @@ type Dentry struct {
parent *Dentry
name string
+ // If cached is true, dentryEntry links dentry into
+ // Filesystem.cachedDentries. cached and dentryEntry are protected by
+ // Filesystem.mu.
+ cached bool
+ dentryEntry
+
// dirMu protects children and the names of child Dentries.
//
// Note that holding fs.mu for writing is not sufficient;
@@ -188,6 +210,150 @@ type Dentry struct {
inode Inode
}
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *Dentry) IncRef() {
+ // d.refs may be 0 if d.fs.mu is locked, which serializes against
+ // d.cacheLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *Dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef(ctx context.Context) {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.mu.Lock()
+ d.cacheLocked(ctx)
+ d.fs.mu.Unlock()
+ } else if refs < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+}
+
+// cacheLocked should be called after d's reference count becomes 0. The ref
+// count check may happen before acquiring d.fs.mu so there might be a race
+// condition where the ref count is increased again by the time the caller
+// acquires d.fs.mu. This race is handled.
+// Only reachable dentries are added to the cache. However, a dentry might
+// become unreachable *while* it is in the cache due to invalidation.
+//
+// Preconditions: d.fs.mu must be locked for writing.
+func (d *Dentry) cacheLocked(ctx context.Context) {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires d.fs.mu, so if d.refs is zero then it will
+ // remain zero while we hold d.fs.mu for writing.)
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d))
+ }
+ if refs > 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ // If the dentry is deleted and invalidated or has no parent, then it is no
+ // longer reachable by path resolution and should be dropped immediately
+ // because it has zero references.
+ // Note that a dentry may not always have a parent; for example magic links
+ // as described in Inode.Getlink.
+ if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil {
+ if !isDead {
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
+ }
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked(ctx)
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries {
+ return
+ }
+ // Evict the least recently used dentry because cache size is greater than
+ // max cache size (configured on mount).
+ victim := d.fs.cachedDentries.Back()
+ d.fs.cachedDentries.Remove(victim)
+ d.fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // after it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if !victim.vfsd.IsDead() {
+ victim.parent.dirMu.Lock()
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry())
+ delete(victim.parent.children, victim.name)
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.MaxCachedDentries, so we don't loop.
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions:
+// * d.fs.mu must be locked for writing.
+// * d.refs == 0.
+// * d should have been removed from d.parent.children, i.e. d is not reachable
+// by path traversal.
+// * d.vfsd.IsDead() is true.
+func (d *Dentry) destroyLocked(ctx context.Context) {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
+ d.inode.DecRef(ctx) // IncRef from Init.
+ d.inode = nil
+
+ // Drop the reference held by d on its parent without recursively locking
+ // d.fs.mu.
+ if d.parent != nil {
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.cacheLocked(ctx)
+ } else if refs < 0 {
+ panic("kernfs.Dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
// Init initializes this dentry.
//
// Precondition: Caller must hold a reference on inode.
@@ -197,6 +363,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
d.vfsd.Init(d)
d.fs = fs
d.inode = inode
+ atomic.StoreInt64(&d.refs, 1)
ftype := inode.Mode().FileType()
if ftype == linux.ModeDirectory {
d.flags |= dflagsIsDir
@@ -204,7 +371,6 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) {
if ftype == linux.ModeSymlink {
d.flags |= dflagsIsSymlink
}
- d.EnableLeakCheck()
}
// VFSDentry returns the generic vfs dentry for this kernfs dentry.
@@ -222,32 +388,6 @@ func (d *Dentry) isSymlink() bool {
return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
}
-// DecRef implements vfs.DentryImpl.DecRef.
-func (d *Dentry) DecRef(ctx context.Context) {
- decRefParent := false
- d.fs.mu.Lock()
- d.DentryRefs.DecRef(func() {
- d.inode.DecRef(ctx) // IncRef from Init.
- d.inode = nil
- if d.parent != nil {
- // We will DecRef d.parent once all locks are dropped.
- decRefParent = true
- d.parent.dirMu.Lock()
- // Remove d from parent.children. It might already have been
- // removed due to invalidation.
- if _, ok := d.parent.children[d.name]; ok {
- delete(d.parent.children, d.name)
- d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry())
- }
- d.parent.dirMu.Unlock()
- }
- })
- d.fs.mu.Unlock()
- if decRefParent {
- d.parent.DecRef(ctx) // IncRef from Dentry.insertChild.
- }
-}
-
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
//
// Although Linux technically supports inotify on pseudo filesystems (inotify
@@ -267,7 +407,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {}
// this dentry. This does not update the directory inode, so calling this on its
// own isn't sufficient to insert a child into a directory.
//
-// Precondition: d must represent a directory inode.
+// Preconditions:
+// * d must represent a directory inode.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChild(name string, child *Dentry) {
d.dirMu.Lock()
d.insertChildLocked(name, child)
@@ -280,6 +422,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) {
// Preconditions:
// * d must represent a directory inode.
// * d.dirMu must be locked.
+// * d.fs.mu must be locked for at least reading.
func (d *Dentry) insertChildLocked(name string, child *Dentry) {
if !d.isDir() {
panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d))
@@ -436,7 +579,7 @@ type inodeDirectory interface {
// the inode is a directory.
//
// The child returned by Lookup will be hashed into the VFS dentry tree,
- // atleast for the duration of the current FS operation.
+ // at least for the duration of the current FS operation.
//
// Lookup must return the child with an extra reference whose ownership is
// transferred to the dentry that is created to point to that inode. If
@@ -454,7 +597,7 @@ type inodeDirectory interface {
// inside the entries returned by this IterDirents invocation. In other words,
// 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
// the return value, while 'relOffset' is the place to start iteration.
- IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
+ IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
}
type inodeSymlink interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 82fa19c03..2418eec44 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file."
// RootDentryFn is a generator function for creating the root dentry of a test
// filesystem. See newTestSystem.
-type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode
+type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode
// newTestSystem sets up a minimal environment for running a test, including an
// instance of a test filesystem. Tests can control the contents of the
@@ -72,10 +72,10 @@ type file struct {
content string
}
-func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode {
+func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode {
f := &file{}
f.content = content
- f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
+ f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
return f
}
@@ -105,9 +105,9 @@ type readonlyDir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &readonlyDir{}
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
dir.EnableLeakCheck()
dir.IncLinks(dir.OrderedChildren.Populate(contents))
@@ -142,10 +142,10 @@ type dir struct {
fs *filesystem
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
dir := &dir{}
dir.fs = fs
- dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
dir.EnableLeakCheck()
@@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) {
func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- dir := d.fs.newDir(creds, opts.Mode, nil)
+ dir := d.fs.newDir(ctx, creds, opts.Mode, nil)
if err := d.OrderedChildren.Insert(name, dir); err != nil {
dir.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
d.IncLinks(1)
return dir, nil
}
func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) {
creds := auth.CredentialsFromContext(ctx)
- f := d.fs.newFile(creds, "")
+ f := d.fs.newFile(ctx, creds, "")
if err := d.OrderedChildren.Insert(name, f); err != nil {
f.DecRef(ctx)
return nil, err
}
+ d.TouchCMtime(ctx)
return f, nil
}
@@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {}
func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
fs := &filesystem{}
fs.VFSFilesystem().Init(vfsObj, &fst, fs)
- root := fst.rootFn(creds, fs)
+ root := fst.rootFn(ctx, creds, fs)
var d kernfs.Dentry
d.Init(&fs.Filesystem, root)
return fs.VFSFilesystem(), d.VFSDentry(), nil
@@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst
// -------------------- Remainder of the file are test cases --------------------
func TestBasic(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -228,9 +230,9 @@ func TestBasic(t *testing.T) {
}
func TestMkdirGetDentry(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) {
}
func TestReadStaticFile(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "file1": fs.newFile(creds, staticFileContent),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
@@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) {
}
func TestCreateNewFileInStaticDir(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
- "dir1": fs.newDir(creds, 0755, nil),
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir1": fs.newDir(ctx, creds, 0755, nil),
})
})
defer sys.Destroy()
@@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) {
}
func TestDirFDReadWrite(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, nil)
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, nil)
})
defer sys.Destroy()
@@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) {
}
func TestDirFDIterDirents(t *testing.T) {
- sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode {
- return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{
+ sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode {
+ return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{
// Fill root with nodes backed by various inode implementations.
- "dir1": fs.newReadonlyDir(creds, 0755, nil),
- "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{
- "dir3": fs.newDir(creds, 0755, nil),
+ "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil),
+ "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{
+ "dir3": fs.newDir(ctx, creds, 0755, nil),
}),
- "file1": fs.newFile(creds, staticFileContent),
+ "file1": fs.newFile(ctx, creds, staticFileContent),
})
})
defer sys.Destroy()
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 934cc6c9e..a0736c0d6 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -38,16 +38,16 @@ type StaticSymlink struct {
var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
+func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode {
inode := &StaticSymlink{}
- inode.Init(creds, devMajor, devMinor, ino, target)
+ inode.Init(ctx, creds, devMajor, devMinor, ino, target)
return inode
}
// Init initializes the instance.
-func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
+func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
s.target = target
- s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
+ s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
}
// Readlink implements Inode.Readlink.
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
index d0ed17b18..463d77d79 100644
--- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
+++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go
@@ -41,17 +41,17 @@ type syntheticDirectory struct {
var _ Inode = (*syntheticDirectory)(nil)
-func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode {
+func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode {
inode := &syntheticDirectory{}
- inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
+ inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm)
return inode
}
-func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm))
}
- dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
+ dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm)
dir.OrderedChildren.Init(OrderedChildrenOptions{
Writable: true,
})
@@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs
if !opts.ForSyntheticMountpoint {
return nil, syserror.EPERM
}
- subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
+ subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask)
if err := dir.OrderedChildren.Insert(name, subdirI); err != nil {
subdirI.DecRef(ctx)
return nil, err
}
+ dir.TouchCMtime(ctx)
return subdirI, nil
}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
index 8cf5b35d3..fd6c55921 100644
--- a/pkg/sentry/fsimpl/overlay/BUILD
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -21,14 +21,18 @@ go_library(
"directory.go",
"filesystem.go",
"fstree.go",
- "non_directory.go",
"overlay.go",
+ "regular_file.go",
+ "save_restore.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refsvfs2",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
@@ -37,5 +41,6 @@ go_library(
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
+ "//pkg/waiter",
],
)
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 73b126669..4506642ca 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -75,8 +75,21 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
return syserror.ENOENT
}
- // Perform copy-up.
+ // Obtain settable timestamps from the lower layer.
vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ oldpop := vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }
+ const timestampsMask = linux.STATX_ATIME | linux.STATX_MTIME
+ oldStat, err := vfsObj.StatAt(ctx, d.fs.creds, &oldpop, &vfs.StatOptions{
+ Mask: timestampsMask,
+ })
+ if err != nil {
+ return err
+ }
+
+ // Perform copy-up.
newpop := vfs.PathOperation{
Root: d.parent.upperVD,
Start: d.parent.upperVD,
@@ -101,10 +114,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
}
switch ftype {
case linux.S_IFREG:
- oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.lowerVDs[0],
- Start: d.lowerVDs[0],
- }, &vfs.OpenOptions{
+ oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &oldpop, &vfs.OpenOptions{
Flags: linux.O_RDONLY,
})
if err != nil {
@@ -160,9 +170,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
}
if err := newFD.SetStat(ctx, vfs.SetStatOptions{
Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: d.uid,
- GID: d.gid,
+ Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask&timestampsMask,
+ UID: d.uid,
+ GID: d.gid,
+ Atime: oldStat.Atime,
+ Mtime: oldStat.Mtime,
},
}); err != nil {
cleanupUndoCopyUp()
@@ -179,9 +191,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
}
if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: d.uid,
- GID: d.gid,
+ Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask&timestampsMask,
+ UID: d.uid,
+ GID: d.gid,
+ Atime: oldStat.Atime,
+ Mtime: oldStat.Mtime,
},
}); err != nil {
cleanupUndoCopyUp()
@@ -195,10 +209,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
d.upperVD = upperVD
case linux.S_IFLNK:
- target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.lowerVDs[0],
- Start: d.lowerVDs[0],
- })
+ target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &oldpop)
if err != nil {
return err
}
@@ -207,10 +218,12 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
}
if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
Stat: linux.Statx{
- Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID,
- Mode: uint16(d.mode),
- UID: d.uid,
- GID: d.gid,
+ Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | oldStat.Mask&timestampsMask,
+ Mode: uint16(d.mode),
+ UID: d.uid,
+ GID: d.gid,
+ Atime: oldStat.Atime,
+ Mtime: oldStat.Mtime,
},
}); err != nil {
cleanupUndoCopyUp()
@@ -224,25 +237,20 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
d.upperVD = upperVD
case linux.S_IFBLK, linux.S_IFCHR:
- lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
- Root: d.lowerVDs[0],
- Start: d.lowerVDs[0],
- }, &vfs.StatOptions{})
- if err != nil {
- return err
- }
if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{
Mode: linux.FileMode(d.mode),
- DevMajor: lowerStat.RdevMajor,
- DevMinor: lowerStat.RdevMinor,
+ DevMajor: oldStat.RdevMajor,
+ DevMinor: oldStat.RdevMinor,
}); err != nil {
return err
}
if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
Stat: linux.Statx{
- Mask: linux.STATX_UID | linux.STATX_GID,
- UID: d.uid,
- GID: d.gid,
+ Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask&timestampsMask,
+ UID: d.uid,
+ GID: d.gid,
+ Atime: oldStat.Atime,
+ Mtime: oldStat.Mtime,
},
}); err != nil {
cleanupUndoCopyUp()
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index bd11372d5..10161a08d 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -302,8 +302,14 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str
child.devMinor = fs.dirDevMinor
child.ino = fs.newDirIno()
} else if !child.upperVD.Ok() {
+ childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor)
+ if err != nil {
+ ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err)
+ child.destroyLocked(ctx)
+ return nil, err
+ }
child.devMajor = linux.UNNAMED_MAJOR
- child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ child.devMinor = childDevMinor
}
parent.IncRef()
@@ -765,7 +771,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- if mayWrite {
+ if start.isRegularFile() && mayWrite {
if err := start.copyUpLocked(ctx); err != nil {
return nil, err
}
@@ -819,7 +825,7 @@ afterTrailingSymlink:
if rp.MustBeDir() && !child.isDir() {
return nil, syserror.ENOTDIR
}
- if mayWrite {
+ if child.isRegularFile() && mayWrite {
if err := child.copyUpLocked(ctx); err != nil {
return nil, err
}
@@ -872,8 +878,11 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *
if err != nil {
return nil, err
}
+ if ftype != linux.S_IFREG {
+ return layerFD, nil
+ }
layerFlags := layerFD.StatusFlags()
- fd := &nonDirectoryFD{
+ fd := &regularFileFD{
copiedUp: isUpper,
cachedFD: layerFD,
cachedFlags: layerFlags,
@@ -969,7 +978,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving
}
// Finally construct the overlay FD.
upperFlags := upperFD.StatusFlags()
- fd := &nonDirectoryFD{
+ fd := &regularFileFD{
copiedUp: true,
cachedFD: upperFD,
cachedFlags: upperFlags,
@@ -1293,6 +1302,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
if !child.isDir() {
return syserror.ENOTDIR
}
+ if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(atomic.LoadUint32(&parent.mode)), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil {
+ return err
+ }
child.dirMu.Lock()
defer child.dirMu.Unlock()
whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx)
@@ -1528,12 +1540,38 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
+ parentMode := atomic.LoadUint32(&parent.mode)
child := parent.children[name]
var childLayer lookupLayer
+ if child == nil {
+ if parentMode&linux.S_ISVTX != 0 {
+ // If the parent's sticky bit is set, we need a child dentry to get
+ // its owner.
+ child, err = fs.getChildLocked(ctx, parent, name, &ds)
+ if err != nil {
+ return err
+ }
+ } else {
+ // Determine if the file being unlinked actually exists. Holding
+ // parent.dirMu prevents a dentry from being instantiated for the file,
+ // which in turn prevents it from being copied-up, so this result is
+ // stable.
+ childLayer, err = fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if !childLayer.existsInOverlay() {
+ return syserror.ENOENT
+ }
+ }
+ }
if child != nil {
if child.isDir() {
return syserror.EISDIR
}
+ if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(parentMode), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil {
+ return err
+ }
if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
return err
}
@@ -1546,18 +1584,6 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
} else {
childLayer = lookupLayerLower
}
- } else {
- // Determine if the file being unlinked actually exists. Holding
- // parent.dirMu prevents a dentry from being instantiated for the file,
- // which in turn prevents it from being copied-up, so this result is
- // stable.
- childLayer, err = fs.lookupLayerLocked(ctx, parent, name)
- if err != nil {
- return err
- }
- if !childLayer.existsInOverlay() {
- return syserror.ENOENT
- }
}
pop := vfs.PathOperation{
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
index e5f506d2e..c812f0a70 100644
--- a/pkg/sentry/fsimpl/overlay/overlay.go
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -18,10 +18,11 @@
//
// Lock order:
//
-// directoryFD.mu / nonDirectoryFD.mu
+// directoryFD.mu / regularFileFD.mu
// filesystem.renameMu
// dentry.dirMu
// dentry.copyMu
+// filesystem.devMu
// *** "memmap.Mappable locks" below this point
// dentry.mapsMu
// *** "memmap.Mappable locks taken by Translate" below this point
@@ -33,12 +34,14 @@
package overlay
import (
+ "fmt"
"strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -99,10 +102,15 @@ type filesystem struct {
// is immutable.
dirDevMinor uint32
- // lowerDevMinors maps lower layer filesystems to device minor numbers
- // assigned to non-directory files originating from that filesystem.
- // lowerDevMinors is immutable.
- lowerDevMinors map[*vfs.Filesystem]uint32
+ // lowerDevMinors maps device numbers from lower layer filesystems to
+ // device minor numbers assigned to non-directory files originating from
+ // that filesystem. (This remapping is necessary for lower layers because a
+ // file on a lower layer, and that same file on an overlay, are
+ // distinguishable because they will diverge after copy-up; this isn't true
+ // for non-directory files already on the upper layer.) lowerDevMinors is
+ // protected by devMu.
+ devMu sync.Mutex `state:"nosave"`
+ lowerDevMinors map[layerDevNumber]uint32
// renameMu synchronizes renaming with non-renaming operations in order to
// ensure consistent lock ordering between dentry.dirMu in different
@@ -114,78 +122,69 @@ type filesystem struct {
lastDirIno uint64
}
+// +stateify savable
+type layerDevNumber struct {
+ major uint32
+ minor uint32
+}
+
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
mopts := vfs.GenericParseMountOptions(opts.Data)
fsoptsRaw := opts.InternalData
- fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
- if fsoptsRaw != nil && !haveFSOpts {
+ fsopts, ok := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !ok {
ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
return nil, nil, syserror.EINVAL
}
- if haveFSOpts {
- if len(fsopts.LowerRoots) == 0 {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+
+ if upperPathname, ok := mopts["upperdir"]; ok {
+ if fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified")
return nil, nil, syserror.EINVAL
}
- if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
return nil, nil, syserror.EINVAL
}
- // We don't enforce a maximum number of lower layers when not
- // configured by applications; the sandbox owner can have an overlay
- // filesystem with any number of lower layers.
- } else {
- vfsroot := vfs.RootFromContext(ctx)
- defer vfsroot.DecRef(ctx)
- upperPathname, ok := mopts["upperdir"]
- if ok {
- delete(mopts, "upperdir")
- // Linux overlayfs also requires a workdir when upperdir is
- // specified; we don't, so silently ignore this option.
- delete(mopts, "workdir")
- upperPath := fspath.Parse(upperPathname)
- if !upperPath.Absolute {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
- return nil, nil, syserror.EINVAL
- }
- upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
- Root: vfsroot,
- Start: vfsroot,
- Path: upperPath,
- FollowFinalSymlink: true,
- }, &vfs.GetDentryOptions{
- CheckSearchable: true,
- })
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer upperRoot.DecRef(ctx)
- privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
- if err != nil {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
- return nil, nil, err
- }
- defer privateUpperRoot.DecRef(ctx)
- fsopts.UpperRoot = privateUpperRoot
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ upperRoot.DecRef(ctx)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
}
- lowerPathnamesStr, ok := mopts["lowerdir"]
- if !ok {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ defer privateUpperRoot.DecRef(ctx)
+ fsopts.UpperRoot = privateUpperRoot
+ }
+
+ if lowerPathnamesStr, ok := mopts["lowerdir"]; ok {
+ if len(fsopts.LowerRoots) != 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified")
return nil, nil, syserror.EINVAL
}
delete(mopts, "lowerdir")
lowerPathnames := strings.Split(lowerPathnamesStr, ":")
- const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
- if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
- return nil, nil, syserror.EINVAL
- }
- if len(lowerPathnames) > maxLowerLayers {
- ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
- return nil, nil, syserror.EINVAL
- }
for _, lowerPathname := range lowerPathnames {
lowerPath := fspath.Parse(lowerPathname)
if !lowerPath.Absolute {
@@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
}
- defer lowerRoot.DecRef(ctx)
privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ lowerRoot.DecRef(ctx)
if err != nil {
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
return nil, nil, err
@@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
}
}
+
if len(mopts) != 0 {
ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
return nil, nil, syserror.EINVAL
}
- // Allocate device numbers.
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present")
+ return nil, nil, syserror.EINVAL
+ }
+ const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
+ if len(fsopts.LowerRoots) > maxLowerLayers {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate dirDevMinor. lowerDevMinors are allocated dynamically.
dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
if err != nil {
return nil, nil, err
}
- lowerDevMinors := make(map[*vfs.Filesystem]uint32)
- for _, lowerRoot := range fsopts.LowerRoots {
- lowerFS := lowerRoot.Mount().Filesystem()
- if _, ok := lowerDevMinors[lowerFS]; !ok {
- devMinor, err := vfsObj.GetAnonBlockDevMinor()
- if err != nil {
- vfsObj.PutAnonBlockDevMinor(dirDevMinor)
- for _, lowerDevMinor := range lowerDevMinors {
- vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
- }
- return nil, nil, err
- }
- lowerDevMinors[lowerFS] = devMinor
- }
- }
// Take extra references held by the filesystem.
if fsopts.UpperRoot.Ok() {
@@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
opts: fsopts,
creds: creds.Fork(),
dirDevMinor: dirDevMinor,
- lowerDevMinors: lowerDevMinors,
+ lowerDevMinors: make(map[layerDevNumber]uint32),
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
root.ino = fs.newDirIno()
} else if !root.upperVD.Ok() {
root.devMajor = linux.UNNAMED_MAJOR
- root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor)
+ if err != nil {
+ ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err)
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+ root.devMinor = rootDevMinor
root.ino = rootStat.Ino
} else {
root.devMajor = rootStat.DevMajor
@@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 {
return atomic.AddUint64(&fs.lastDirIno, 1)
}
+func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) {
+ fs.devMu.Lock()
+ defer fs.devMu.Unlock()
+ orig := layerDevNumber{layerMajor, layerMinor}
+ if minor, ok := fs.lowerDevMinors[orig]; ok {
+ return minor, nil
+ }
+ minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor()
+ if err != nil {
+ return 0, err
+ }
+ fs.lowerDevMinors[orig] = minor
+ return minor, nil
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -453,14 +474,14 @@ type dentry struct {
// - If this dentry is copied-up, then wrappedMappable is the Mappable
// obtained from a call to the current top layer's
// FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil
- // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil.
+ // (from a call to regularFileFD.ensureMappable()), it cannot become nil.
// wrappedMappable is protected by mapsMu and dataMu.
//
// - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is
// accessed using atomic memory operations.
- mapsMu sync.Mutex
+ mapsMu sync.Mutex `state:"nosave"`
lowerMappings memmap.MappingSet
- dataMu sync.RWMutex
+ dataMu sync.RWMutex `state:"nosave"`
wrappedMappable memmap.Mappable
isMappable uint32
@@ -484,6 +505,9 @@ func (fs *filesystem) newDentry() *dentry {
}
d.lowerVDs = d.inlineLowerVDs[:0]
d.vfsd.Init(d)
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(d, "overlay.dentry")
+ }
return d
}
@@ -583,6 +607,14 @@ func (d *dentry) destroyLocked(ctx context.Context) {
panic("overlay.dentry.DecRef() called without holding a reference")
}
}
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Unregister(d, "overlay.dentry")
+ }
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
}
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 853aee951..2b89a7a6d 100644
--- a/pkg/sentry/fsimpl/overlay/non_directory.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -19,14 +19,21 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
)
+func (d *dentry) isRegularFile() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFREG
+}
+
func (d *dentry) isSymlink() bool {
return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
}
@@ -40,7 +47,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) {
}
// +stateify savable
-type nonDirectoryFD struct {
+type regularFileFD struct {
fileDescription
// If copiedUp is false, cachedFD represents
@@ -52,9 +59,13 @@ type nonDirectoryFD struct {
copiedUp bool
cachedFD *vfs.FileDescription
cachedFlags uint32
+
+ // If copiedUp is false, lowerWaiters contains all waiter.Entries
+ // registered with cachedFD. lowerWaiters is protected by mu.
+ lowerWaiters map[*waiter.Entry]waiter.EventMask
}
-func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) {
+func (fd *regularFileFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) {
fd.mu.Lock()
defer fd.mu.Unlock()
wrappedFD, err := fd.currentFDLocked(ctx)
@@ -65,7 +76,7 @@ func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescriptio
return wrappedFD, nil
}
-func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) {
+func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) {
d := fd.dentry()
statusFlags := fd.vfsfd.StatusFlags()
if !fd.copiedUp && d.isCopiedUp() {
@@ -87,10 +98,21 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip
return nil, err
}
}
+ if len(fd.lowerWaiters) != 0 {
+ ready := upperFD.Readiness(^waiter.EventMask(0))
+ for e, mask := range fd.lowerWaiters {
+ fd.cachedFD.EventUnregister(e)
+ upperFD.EventRegister(e, mask)
+ if ready&mask != 0 {
+ e.Callback.Callback(e)
+ }
+ }
+ }
fd.cachedFD.DecRef(ctx)
fd.copiedUp = true
fd.cachedFD = upperFD
fd.cachedFlags = statusFlags
+ fd.lowerWaiters = nil
} else if fd.cachedFlags != statusFlags {
if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil {
return nil, err
@@ -101,13 +123,13 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *nonDirectoryFD) Release(ctx context.Context) {
+func (fd *regularFileFD) Release(ctx context.Context) {
fd.cachedFD.DecRef(ctx)
fd.cachedFD = nil
}
// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *nonDirectoryFD) OnClose(ctx context.Context) error {
+func (fd *regularFileFD) OnClose(ctx context.Context) error {
// Linux doesn't define ovl_file_operations.flush at all (i.e. its
// equivalent to OnClose is a no-op). We pass through to
// fd.cachedFD.OnClose() without upgrading if fd.dentry() has been
@@ -128,7 +150,7 @@ func (fd *nonDirectoryFD) OnClose(ctx context.Context) error {
}
// Stat implements vfs.FileDescriptionImpl.Stat.
-func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+func (fd *regularFileFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
var stat linux.Statx
if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
wrappedFD, err := fd.getCurrentFD(ctx)
@@ -149,7 +171,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
-func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
wrappedFD, err := fd.getCurrentFD(ctx)
if err != nil {
return err
@@ -159,7 +181,7 @@ func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uin
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
-func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
d := fd.dentry()
mode := linux.FileMode(atomic.LoadUint32(&d.mode))
if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
@@ -191,12 +213,61 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions)
}
// StatFS implements vfs.FileDescriptionImpl.StatFS.
-func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
+func (fd *regularFileFD) StatFS(ctx context.Context) (linux.Statfs, error) {
return fd.filesystem().statFS(ctx)
}
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *regularFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ctx := context.Background()
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ // TODO(b/171089913): Just use fd.cachedFD since Readiness can't return
+ // an error. This is obviously wrong, but at least consistent with
+ // VFS1.
+ log.Warningf("overlay.regularFileFD.Readiness: currentFDLocked failed: %v", err)
+ fd.mu.Lock()
+ wrappedFD = fd.cachedFD
+ wrappedFD.IncRef()
+ fd.mu.Unlock()
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *regularFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(context.Background())
+ if err != nil {
+ // TODO(b/171089913): Just use fd.cachedFD since EventRegister can't
+ // return an error. This is obviously wrong, but at least consistent
+ // with VFS1.
+ log.Warningf("overlay.regularFileFD.EventRegister: currentFDLocked failed: %v", err)
+ wrappedFD = fd.cachedFD
+ }
+ wrappedFD.EventRegister(e, mask)
+ if !fd.copiedUp {
+ if fd.lowerWaiters == nil {
+ fd.lowerWaiters = make(map[*waiter.Entry]waiter.EventMask)
+ }
+ fd.lowerWaiters[e] = mask
+ }
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *regularFileFD) EventUnregister(e *waiter.Entry) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ fd.cachedFD.EventUnregister(e)
+ if !fd.copiedUp {
+ delete(fd.lowerWaiters, e)
+ }
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
wrappedFD, err := fd.getCurrentFD(ctx)
if err != nil {
return 0, err
@@ -206,7 +277,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off
}
// Read implements vfs.FileDescriptionImpl.Read.
-func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
// Hold fd.mu during the read to serialize the file offset.
fd.mu.Lock()
defer fd.mu.Unlock()
@@ -218,7 +289,7 @@ func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
wrappedFD, err := fd.getCurrentFD(ctx)
if err != nil {
return 0, err
@@ -228,7 +299,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of
}
// Write implements vfs.FileDescriptionImpl.Write.
-func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
// Hold fd.mu during the write to serialize the file offset.
fd.mu.Lock()
defer fd.mu.Unlock()
@@ -240,7 +311,7 @@ func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opt
}
// Seek implements vfs.FileDescriptionImpl.Seek.
-func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
// Hold fd.mu during the seek to serialize the file offset.
fd.mu.Lock()
defer fd.mu.Unlock()
@@ -252,7 +323,7 @@ func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32)
}
// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
+func (fd *regularFileFD) Sync(ctx context.Context) error {
fd.mu.Lock()
if !fd.dentry().isCopiedUp() {
fd.mu.Unlock()
@@ -269,8 +340,18 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
return wrappedFD.Sync(ctx)
}
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *regularFileFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.Ioctl(ctx, uio, args)
+}
+
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
-func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
if err := fd.ensureMappable(ctx, opts); err != nil {
return err
}
@@ -278,7 +359,7 @@ func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOp
}
// ensureMappable ensures that fd.dentry().wrappedMappable is not nil.
-func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error {
+func (fd *regularFileFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error {
d := fd.dentry()
// Fast path if we already have a Mappable for the current top layer.
diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go
new file mode 100644
index 000000000..054e17b17
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d, "overlay.dentry")
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 2e086e34c..5196a2a80 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "fd_dir_inode_refs.go",
package = "proc",
prefix = "fdDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdDirInode",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "fd_info_dir_inode_refs.go",
package = "proc",
prefix = "fdInfoDirInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "fdInfoDirInode",
},
@@ -30,7 +30,7 @@ go_template_instance(
out = "subtasks_inode_refs.go",
package = "proc",
prefix = "subtasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "subtasksInode",
},
@@ -41,7 +41,7 @@ go_template_instance(
out = "task_inode_refs.go",
package = "proc",
prefix = "taskInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "taskInode",
},
@@ -52,7 +52,7 @@ go_template_instance(
out = "tasks_inode_refs.go",
package = "proc",
prefix = "tasksInode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tasksInode",
},
@@ -82,6 +82,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fsbridge",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index fd70a07de..99abcab66 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -17,6 +17,7 @@ package proc
import (
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -24,10 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "proc"
+const (
+ // Name is the default filesystem name.
+ Name = "proc"
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType is the factory class for procfs.
//
@@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
if err != nil {
return nil, nil, err
}
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
procfs := &filesystem{
devMinor: devMinor,
}
+ procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
var cgroups map[string]string
@@ -74,7 +92,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
cgroups = data.Cgroups
}
- inode := procfs.newTasksInode(k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
var dentry kernfs.Dentry
dentry.Init(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
@@ -94,11 +112,11 @@ type dynamicInode interface {
kernfs.Inode
vfs.DynamicBytesSource
- Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+ Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
}
-func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
- inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
+func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode {
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm)
return inode
}
@@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile {
return &staticFile{StaticData: vfs.StaticData{Data: data}}
}
-func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
- return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
+func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode {
+ return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{
SeekEnd: kernfs.SeekEndZero,
})
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index bad2fab4f..cb3c5e0fd 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace,
cgroupControllers: cgroupControllers,
}
// Note: credentials are overridden by taskOwnedInode.
- subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
subInode.EnableLeakCheck()
@@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
if len(tasks) == 0 {
return offset, syserror.ENOENT
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index b63a4eca0..19011b010 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
"gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
"io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
"maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mem": fs.newMemInode(task, fs.NextIno(), 0400),
"mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
"mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}),
"net": fs.newTaskNetDir(task),
@@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace
taskInode := &taskInode{task: task}
// Note: credentials are overridden by taskOwnedInode.
- taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
taskInode.EnableLeakCheck()
inode := &taskOwnedInode{Inode: taskInode, owner: task}
@@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil)
func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
return &taskOwnedInode{Inode: inode, owner: task}
}
@@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu
func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode {
// Note: credentials are overridden by taskOwnedInode.
fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero}
- dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
+ dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts)
return &taskOwnedInode{Inode: dir, owner: task}
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 2c80ac5c2..d268b44be 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -64,7 +64,7 @@ type fdDir struct {
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
var fds []int32
i.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
@@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode {
produceSymlink: true,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Lookup implements kernfs.inodeDirectory.Lookup.
@@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern
task: task,
fd: fd,
}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode {
task: task,
},
}
- inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
return inode
@@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode,
}
// IterDirents implements Inode.IterDirents.
-func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
- return i.fdDir.IterDirents(ctx, cb, offset, relOffset)
+func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 79f8b7e9f..ba71d0fde 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -249,7 +250,7 @@ type commInode struct {
func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
inode := &commInode{task: task}
- inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
+ inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
return inode
}
@@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in
return int64(srclen), nil
}
+var _ kernfs.Inode = (*memInode)(nil)
+
+// memInode implements kernfs.Inode for /proc/[pid]/mem.
+//
+// +stateify savable
+type memInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoStatFS
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ task *kernel.Task
+ locks vfs.FileLocks
+}
+
+func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode := &memInode{task: task}
+ inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm)
+ return &taskOwnedInode{Inode: inode, owner: task}
+}
+
+func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+}
+
+// Open implements kernfs.Inode.Open.
+func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS
+ // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS
+ // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH
+ if !kernel.ContextCanTrace(ctx, f.task, true) {
+ return nil, syserror.EACCES
+ }
+ if err := checkTaskState(f.task); err != nil {
+ return nil, err
+ }
+ fd := &memFD{}
+ if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+var _ vfs.FileDescriptionImpl = (*memFD)(nil)
+
+// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem.
+//
+// +stateify savable
+type memFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *memInode
+
+ // mu guards the fields below.
+ mu sync.Mutex `state:"nosave"`
+ offset int64
+}
+
+// Init initializes memFD.
+func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error {
+ fd.LockFD.Init(&inode.locks)
+ if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return err
+ }
+ fd.inode = inode
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ case linux.SEEK_CUR:
+ offset += fd.offset
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.offset = offset
+ return offset, nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ m, err := getMMIncRef(fd.inode.task)
+ if err != nil {
+ return 0, nil
+ }
+ defer m.DecUsers(ctx)
+ // Buffer the read data because of MM locks
+ buf := make([]byte, dst.NumBytes())
+ n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true})
+ if n > 0 {
+ if _, err := dst.CopyOut(ctx, buf[:n]); err != nil {
+ return 0, syserror.EFAULT
+ }
+ return int64(n), nil
+ }
+ if readErr != nil {
+ return 0, syserror.EIO
+ }
+ return 0, nil
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.offset, opts)
+ fd.offset += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *memFD) Release(context.Context) {}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil)
func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &exeSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil)
func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode {
inode := &cwdSymlink{task: task}
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
return inode
}
@@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri
inode := &namespaceSymlink{task: task}
// Note: credentials are overridden by taskOwnedInode.
- inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
+ inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
taskInode := &taskOwnedInode{Inode: inode, owner: task}
return taskInode
@@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir
// Create a synthetic inode to represent the namespace.
fs := mnt.Filesystem().Impl().(*filesystem)
+ nsInode := &namespaceInode{}
+ nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444)
dentry := &kernfs.Dentry{}
- dentry.Init(&fs.Filesystem, &namespaceInode{})
+ dentry.Init(&fs.Filesystem, nsInode)
vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
// Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1.
mnt.IncRef()
@@ -897,11 +1056,11 @@ type namespaceInode struct {
var _ kernfs.Inode = (*namespaceInode)(nil)
// Init initializes a namespace inode.
-func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
if perm&^linux.PermissionsMask != 0 {
panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
}
- i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
}
// Open implements kernfs.Inode.Open.
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
index 3425e8698..5a9ee111f 100644
--- a/pkg/sentry/fsimpl/proc/task_net.go
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode {
// TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
// network namespace.
contents = map[string]kernfs.Inode{
- "dev": fs.newInode(root, 0444, &netDevData{stack: stack}),
- "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}),
+ "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}),
+ "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, if the file contains a header the stub is just the header
// otherwise it is an empty file.
- "arp": fs.newInode(root, 0444, newStaticFile(arp)),
- "netlink": fs.newInode(root, 0444, newStaticFile(netlink)),
- "netstat": fs.newInode(root, 0444, &netStatData{}),
- "packet": fs.newInode(root, 0444, newStaticFile(packet)),
- "protocols": fs.newInode(root, 0444, newStaticFile(protocols)),
+ "arp": fs.newInode(task, root, 0444, newStaticFile(arp)),
+ "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)),
+ "netstat": fs.newInode(task, root, 0444, &netStatData{}),
+ "packet": fs.newInode(task, root, 0444, newStaticFile(packet)),
+ "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)),
// Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
// high res timer ticks per sec (ClockGetres returns 1ns resolution).
- "psched": fs.newInode(root, 0444, newStaticFile(psched)),
- "ptype": fs.newInode(root, 0444, newStaticFile(ptype)),
- "route": fs.newInode(root, 0444, &netRouteData{stack: stack}),
- "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}),
- "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}),
- "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}),
+ "psched": fs.newInode(task, root, 0444, newStaticFile(psched)),
+ "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)),
+ "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}),
+ "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}),
+ "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}),
}
if stack.SupportsIPv6() {
- contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack})
- contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile(""))
- contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k})
- contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6))
+ contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6))
}
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 3259c3732..b81ea14bf 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -62,19 +62,19 @@ type tasksInode struct {
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
- "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))),
- "filesystems": fs.newInode(root, 0444, &filesystemsData{}),
- "loadavg": fs.newInode(root, 0444, &loadavgData{}),
- "sys": fs.newSysDir(root, k),
- "meminfo": fs.newInode(root, 0444, &meminfoData{}),
- "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
- "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
- "stat": fs.newInode(root, 0444, &statData{}),
- "uptime": fs.newInode(root, 0444, &uptimeData{}),
- "version": fs.newInode(root, 0444, &versionData{}),
+ "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}),
+ "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}),
+ "sys": fs.newSysDir(ctx, root, k),
+ "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
+ "stat": fs.newInode(ctx, root, 0444, &statData{}),
+ "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
+ "version": fs.newInode(ctx, root, 0444, &versionData{}),
}
inode := &tasksInode{
@@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace
fs: fs,
cgroupControllers: cgroupControllers,
}
- inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.EnableLeakCheck()
inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
@@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
// If it failed to parse, check if it's one of the special handled files.
switch name {
case selfName:
- return i.newSelfSymlink(root), nil
+ return i.newSelfSymlink(ctx, root), nil
case threadSelfName:
- return i.newThreadSelfSymlink(root), nil
+ return i.newThreadSelfSymlink(ctx, root), nil
}
return nil, syserror.ENOENT
}
@@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
-func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
// fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
const FIRST_PROCESS_ENTRY = 256
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 07c27cdd9..01b7a6678 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -43,9 +43,9 @@ type selfSymlink struct {
var _ kernfs.Inode = (*selfSymlink)(nil)
-func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &selfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
@@ -84,9 +84,9 @@ type threadSelfSymlink struct {
var _ kernfs.Inode = (*threadSelfSymlink)(nil)
-func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode {
+func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
inode := &threadSelfSymlink{pidns: i.pidns}
- inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
+ inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777)
return inode
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 95420368d..7c7afdcfa 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -40,93 +40,93 @@ const (
)
// newSysDir returns the dentry corresponding to /proc/sys directory.
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
- return fs.newStaticDir(root, map[string]kernfs.Inode{
- "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{
- "hostname": fs.newInode(root, 0444, &hostnameData{}),
- "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)),
- "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)),
- "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)),
+func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+ return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
}),
- "vm": fs.newStaticDir(root, map[string]kernfs.Inode{
- "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}),
- "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")),
+ "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")),
}),
- "net": fs.newSysNetDir(root, k),
+ "net": fs.newSysNetDir(ctx, root, k),
})
}
// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
-func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
+func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode {
var contents map[string]kernfs.Inode
// TODO(gvisor.dev/issue/1833): Support for using the network stack in the
// network namespace of the calling process.
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]kernfs.Inode{
- "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{
- "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}),
+ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
+ "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")),
- "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")),
- "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")),
- "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")),
- "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")),
+ "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")),
// tcp_allowed_congestion_control tell the user what they are able to
// do as an unprivledged process so we leave it empty.
- "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
- "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")),
+ "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")),
// Many of the following stub files are features netstack doesn't
// support. The unsupported features return "0" to indicate they are
// disabled.
- "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")),
- "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")),
- "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")),
- "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")),
- "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")),
- "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")),
- "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")),
- "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")),
- "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")),
+ "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")),
}),
- "core": fs.newStaticDir(root, map[string]kernfs.Inode{
- "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")),
- "message_burst": fs.newInode(root, 0444, newStaticFile("10")),
- "message_cost": fs.newInode(root, 0444, newStaticFile("5")),
- "optmem_max": fs.newInode(root, 0444, newStaticFile("0")),
- "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
- "somaxconn": fs.newInode(root, 0444, newStaticFile("128")),
- "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")),
- "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")),
+ "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
+ "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")),
+ "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")),
+ "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")),
+ "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")),
+ "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
+ "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")),
}),
}
}
- return fs.newStaticDir(root, contents)
+ return fs.newStaticDir(ctx, root, contents)
}
// mmapMinAddrData implements vfs.DynamicBytesSource for
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 2582ababd..7ee6227a9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -77,6 +77,7 @@ var (
"gid_map": linux.DT_REG,
"io": linux.DT_REG,
"maps": linux.DT_REG,
+ "mem": linux.DT_REG,
"mountinfo": linux.DT_REG,
"mounts": linux.DT_REG,
"net": linux.DT_DIR,
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index cf91ea36c..fda1fa942 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// NewDentry constructs and returns a sockfs dentry.
//
// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
-func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
+func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry {
fs := mnt.Filesystem().Impl().(*filesystem)
// File mode matches net/socket.c:sock_alloc.
filemode := linux.FileMode(linux.S_IFSOCK | 0600)
i := &inode{}
- i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
+ i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
d := &kernfs.Dentry{}
d.Init(&fs.Filesystem, i)
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 906cd52cb..09043b572 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "dir_refs.go",
package = "sys",
prefix = "dir",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "dir",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/coverage",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/kernel",
diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go
index 94366d429..b13f141a8 100644
--- a/pkg/sentry/fsimpl/sys/kcov.go
+++ b/pkg/sentry/fsimpl/sys/kcov.go
@@ -29,7 +29,7 @@ import (
func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
k := &kcovInode{}
- k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
+ k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600)
return k
}
@@ -102,7 +102,7 @@ func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
func (fd *kcovFD) Release(ctx context.Context) {
// kcov instances have reference counts in Linux, but this seems sufficient
// for our purposes.
- fd.kcov.Clear()
+ fd.kcov.Clear(ctx)
}
// SetStat implements vfs.FileDescriptionImpl.SetStat.
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index 1ad679830..7d2147141 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -18,6 +18,7 @@ package sys
import (
"bytes"
"fmt"
+ "strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -29,9 +30,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Name is the default filesystem name.
-const Name = "sysfs"
-const defaultSysDirMode = linux.FileMode(0755)
+const (
+ // Name is the default filesystem name.
+ Name = "sysfs"
+ defaultSysDirMode = linux.FileMode(0755)
+ defaultMaxCachedDentries = uint64(1000)
+)
// FilesystemType implements vfs.FilesystemType.
//
@@ -62,28 +66,40 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
fs := &filesystem{
devMinor: devMinor,
}
+ fs.MaxCachedDentries = maxCachedDentries
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
- root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "block": fs.newDir(creds, defaultSysDirMode, nil),
- "bus": fs.newDir(creds, defaultSysDirMode, nil),
- "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "power_supply": fs.newDir(creds, defaultSysDirMode, nil),
+ root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "block": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil),
}),
- "dev": fs.newDir(creds, defaultSysDirMode, nil),
- "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
- "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
+ "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{
"cpu": cpuDir(ctx, fs, creds),
}),
}),
- "firmware": fs.newDir(creds, defaultSysDirMode, nil),
- "fs": fs.newDir(creds, defaultSysDirMode, nil),
+ "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil),
"kernel": kernelDir(ctx, fs, creds),
- "module": fs.newDir(creds, defaultSysDirMode, nil),
- "power": fs.newDir(creds, defaultSysDirMode, nil),
+ "module": fs.newDir(ctx, creds, defaultSysDirMode, nil),
+ "power": fs.newDir(ctx, creds, defaultSysDirMode, nil),
})
var rootD kernfs.Dentry
rootD.Init(&fs.Filesystem, root)
@@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs
k := kernel.KernelFromContext(ctx)
maxCPUCores := k.ApplicationCores()
children := map[string]kernfs.Inode{
- "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
- "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
+ "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)),
}
for i := uint(0); i < maxCPUCores; i++ {
- children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil)
+ children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil)
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode {
@@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker
var children map[string]kernfs.Inode
if coverage.KcovAvailable() {
children = map[string]kernfs.Inode{
- "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{
+ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{
"kcov": fs.newKcovFile(ctx, creds),
}),
}
}
- return fs.newDir(creds, defaultSysDirMode, children)
+ return fs.newDir(ctx, creds, defaultSysDirMode, children)
}
// Release implements vfs.FilesystemImpl.Release.
@@ -140,9 +156,9 @@ type dir struct {
locks vfs.FileLocks
}
-func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
+func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode {
d := &dir{}
- d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
d.EnableLeakCheck()
d.IncLinks(d.OrderedChildren.Populate(contents))
@@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
+func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode {
c := &cpuFile{maxCores: maxCores}
- c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
+ c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
return c
}
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 1813269e0..738c0c9cc 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -147,7 +147,12 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
FDTable: k.NewFDTable(),
}
- return k.TaskSet().NewTask(config)
+ t, err := k.TaskSet().NewTask(ctx, config)
+ if err != nil {
+ config.ThreadGroup.Release(ctx)
+ return nil, err
+ }
+ return t, nil
}
func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) {
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index 5cd428d64..fe520b6fd 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -31,7 +31,7 @@ go_template_instance(
out = "inode_refs.go",
package = "tmpfs",
prefix = "inode",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "inode",
},
@@ -48,6 +48,7 @@ go_library(
"inode_refs.go",
"named_pipe.go",
"regular_file.go",
+ "save_restore.go",
"socket_file.go",
"symlink.go",
"tmpfs.go",
@@ -60,6 +61,7 @@ go_library(
"//pkg/fspath",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index ce4e3eda7..98680fde9 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -42,7 +42,7 @@ type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
- memFile *pgalloc.MemoryFile
+ memFile *pgalloc.MemoryFile `state:"nosave"`
// memoryUsageKind is the memory accounting category under which pages backing
// this regularFile's contents are accounted.
@@ -92,7 +92,7 @@ type regularFile struct {
func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
file := &regularFile{
- memFile: fs.memFile,
+ memFile: fs.mfp.MemoryFile(),
memoryUsageKind: usage.Tmpfs,
seals: linux.F_SEAL_SEAL,
}
diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go
new file mode 100644
index 000000000..b27f75cc2
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+// afterLoad is called by stateify.
+func (rf *regularFile) afterLoad() {
+ rf.memFile = rf.inode.fs.mfp.MemoryFile()
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index e2a0aac69..4ce859d57 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -61,8 +61,9 @@ type FilesystemType struct{}
type filesystem struct {
vfsfs vfs.Filesystem
- // memFile is used to allocate pages to for regular files.
- memFile *pgalloc.MemoryFile
+ // mfp is used to allocate memory that stores regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
// clock is a realtime clock used to set timestamps in file operations.
clock time.Clock
@@ -106,8 +107,8 @@ type FilesystemOpts struct {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
- if memFileProvider == nil {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
panic("MemoryFileProviderFromContext returned nil")
}
@@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
clock := time.RealtimeClockFromContext(ctx)
fs := filesystem{
- memFile: memFileProvider.MemoryFile(),
+ mfp: mfp,
clock: clock,
devMinor: devMinor,
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 0ca750281..e265be0ee 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -6,6 +6,7 @@ go_library(
name = "verity",
srcs = [
"filesystem.go",
+ "save_restore.go",
"verity.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -15,6 +16,7 @@ go_library(
"//pkg/fspath",
"//pkg/marshal/primitive",
"//pkg/merkletree",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
@@ -38,10 +40,12 @@ go_test(
"//pkg/context",
"//pkg/fspath",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/contexttest",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 03da505e1..2f6050cfd 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
offset, err := strconv.Atoi(off)
if err != nil {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err))
}
// Open parent Merkle tree file to read and verify child's hash.
@@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// The parent Merkle tree file should have been created. If it's
// missing, it indicates an unexpected modification to the file system.
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// contains the expected xattrs. If the file or the xattr does not
// exist, it indicates unexpected modifications to the file system.
if err == syserror.ENOENT || err == syserror.ENODATA {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return nil, err
@@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// unexpected modifications to the file system.
parentSize, err := strconv.Atoi(dataSize)
if err != nil {
- return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
Start: parent.lowerVD,
}, &vfs.StatOptions{})
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err))
}
if err != nil {
return nil, err
@@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de
// Verify returns with success.
var buf bytes.Buffer
if _, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: &buf,
- File: &fdReader,
- Tree: &fdReader,
- Size: int64(parentSize),
- Name: parent.name,
- Mode: uint32(parentStat.Mode),
- UID: parentStat.UID,
- GID: parentStat.GID,
+ Out: &buf,
+ File: &fdReader,
+ Tree: &fdReader,
+ Size: int64(parentSize),
+ Name: parent.name,
+ Mode: uint32(parentStat.Mode),
+ UID: parentStat.UID,
+ GID: parentStat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
ReadOffset: int64(offset),
- ReadSize: int64(merkletree.DigestSize()),
+ ReadSize: int64(merkletree.DigestSize(linux.FS_VERITY_HASH_ALG_SHA256)),
Expected: parent.hash,
DataAndTreeInSameFile: true,
}); err != nil && err != io.EOF {
- return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err))
}
// Cache child hash when it's verified the first time.
@@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
Flags: linux.O_RDONLY,
})
if err == syserror.ENOENT {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err))
}
if err != nil {
return err
@@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
})
if err == syserror.ENODATA {
- return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err))
}
if err != nil {
return err
@@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
size, err := strconv.Atoi(merkleSize)
if err != nil {
- return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
fdReader := vfs.FileReadWriteSeeker{
@@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
var buf bytes.Buffer
params := &merkletree.VerifyParams{
- Out: &buf,
- Tree: &fdReader,
- Size: int64(size),
- Name: d.name,
- Mode: uint32(stat.Mode),
- UID: stat.UID,
- GID: stat.GID,
- ReadOffset: 0,
+ Out: &buf,
+ Tree: &fdReader,
+ Size: int64(size),
+ Name: d.name,
+ Mode: uint32(stat.Mode),
+ UID: stat.UID,
+ GID: stat.GID,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
+ ReadOffset: 0,
// Set read size to 0 so only the metadata is verified.
ReadSize: 0,
Expected: d.hash,
@@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat
}
if _, err := merkletree.Verify(params); err != nil && err != io.EOF {
- return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
+ return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err))
}
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
+ d.size = uint32(size)
return nil
}
// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
if child, ok := parent.children[name]; ok {
+ // If verity is enabled on child, we should check again whether
+ // the file and the corresponding Merkle tree are as expected,
+ // in order to catch deletion/renaming after the last time it's
+ // accessed.
+ if child.verityEnabled() {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ // Get the path to the child dentry. This is only used
+ // to provide path information in failure case.
+ path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD)
+ if err != nil {
+ return nil, err
+ }
+
+ childVD, err := parent.getLowerAt(ctx, vfsObj, name)
+ if err == syserror.ENOENT {
+ // The file was previously accessed. If the
+ // file does not exist now, it indicates an
+ // unexpected modification to the file system.
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer childVD.DecRef(ctx)
+
+ childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
+ // The Merkle tree file was previous accessed. If it
+ // does not exist now, it indicates an unexpected
+ // modification to the file system.
+ if err == syserror.ENOENT {
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path))
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ defer childMerkleVD.DecRef(ctx)
+ }
+
// If enabling verification on files/directories is not allowed
// during runtime, all cached children are already verified. If
// runtime enable is allowed and the parent directory is
@@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s
func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
vfsObj := fs.vfsfs.VirtualFilesystem()
- childFilename := fspath.Parse(name)
- childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: childFilename,
- }, &vfs.GetDentryOptions{})
-
+ childVD, childErr := parent.getLowerAt(ctx, vfsObj, name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childErr != nil && childErr != syserror.ENOENT {
@@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
defer childVD.DecRef(ctx)
}
- childMerkleFilename := merklePrefix + name
- childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
-
+ childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
// We will handle ENOENT separately, as it may indicate unexpected
// modifications to the file system, and may cause a sentry panic.
if childMerkleErr != nil && childMerkleErr != syserror.ENOENT {
@@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// corresponding Merkle tree is found. This indicates an
// unexpected modification to the file system that
// removed/renamed the child.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name))
} else if childErr == nil && childMerkleErr == syserror.ENOENT {
// If in allowRuntimeEnable mode, and the Merkle tree file is
// not created yet, we create an empty Merkle tree file, so that
@@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
Root: parent.lowerVD,
Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
+ Path: fspath.Parse(merklePrefix + name),
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT,
Mode: 0644,
@@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
return nil, err
}
childMerkleFD.DecRef(ctx)
- childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
- Root: parent.lowerVD,
- Start: parent.lowerVD,
- Path: fspath.Parse(childMerkleFilename),
- }, &vfs.GetDentryOptions{})
+ childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name)
if err != nil {
return nil, err
}
@@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// If runtime enable is not allowed. This indicates an
// unexpected modification to the file system that
// removed/renamed the Merkle tree file.
- return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name))
}
} else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT {
// Both the child and the corresponding Merkle tree are missing.
@@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
// TODO(b/167752508): Investigate possible ways to differentiate
// cases that both files are deleted from cases that they never
// exist in the file system.
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name))
}
mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID)
@@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// missing, it indicates an unexpected modification to the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path))
}
return nil, err
}
@@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
// the file system.
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
})
if err != nil {
if err == syserror.ENOENT {
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path))
}
return nil, err
}
@@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf
if err != nil {
if err == syserror.ENOENT {
parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD)
- return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
+ return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath))
}
return nil, err
}
diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go
new file mode 100644
index 000000000..4a161163c
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+func (d *dentry) afterLoad() {
+ if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 {
+ refsvfs2.Register(d, "verity.dentry")
+ }
+}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 8dc9e26bc..92ca6ca6b 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -23,6 +23,7 @@ package verity
import (
"fmt"
+ "math"
"strconv"
"sync/atomic"
@@ -31,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -153,10 +155,10 @@ func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic.
-func alertIntegrityViolation(err error, msg string) error {
+// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+func alertIntegrityViolation(msg string) error {
if noCrashOnVerificationFailure {
- return err
+ return syserror.EIO
}
panic(msg)
}
@@ -236,7 +238,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// the root Merkle file, or it's never generated.
fs.vfsfs.DecRef(ctx)
d.DecRef(ctx)
- return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file")
+ return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
}
d.lowerMerkleVD = lowerMerkleVD
@@ -289,11 +291,12 @@ type dentry struct {
// fs is the owning filesystem. fs is immutable.
fs *filesystem
- // mode, uid and gid are the file mode, owner, and group of the file in
- // the underlying file system.
+ // mode, uid, gid and size are the file mode, owner, group, and size of
+ // the file in the underlying file system.
mode uint32
uid uint32
gid uint32
+ size uint32
// parent is the dentry corresponding to this dentry's parent directory.
// name is this dentry's name in parent. If this dentry is a filesystem
@@ -331,6 +334,9 @@ func (fs *filesystem) newDentry() *dentry {
fs: fs,
}
d.vfsd.Init(d)
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(d, "verity.dentry")
+ }
return d
}
@@ -393,6 +399,9 @@ func (d *dentry) destroyLocked(ctx context.Context) {
if d.lowerVD.Ok() {
d.lowerVD.DecRef(ctx)
}
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Unregister(d, "verity.dentry")
+ }
if d.lowerMerkleVD.Ok() {
d.lowerMerkleVD.DecRef(ctx)
@@ -412,6 +421,11 @@ func (d *dentry) destroyLocked(ctx context.Context) {
}
}
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (d *dentry) LeakMessage() string {
+ return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs))
+}
+
// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
//TODO(b/159261227): Implement InotifyWithParent.
@@ -448,6 +462,16 @@ func (d *dentry) verityEnabled() bool {
return !d.fs.allowRuntimeEnable || len(d.hash) != 0
}
+// getLowerAt returns the dentry in the underlying file system, which is
+// represented by filename relative to d.
+func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) {
+ return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.GetDentryOptions{})
+}
+
func (d *dentry) readlink(ctx context.Context) (string, error) {
return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
Root: d.lowerVD,
@@ -489,6 +513,10 @@ type fileDescription struct {
// directory that contains the current file/directory. This is only used
// if allowRuntimeEnable is set to true.
parentMerkleWriter *vfs.FileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex `state:"nosave"`
+ off int64
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -524,6 +552,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ n := int64(0)
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ n = fd.off
+ case linux.SEEK_END:
+ n = int64(fd.d.size)
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset > math.MaxInt64-n {
+ return 0, syserror.EINVAL
+ }
+ offset += n
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
// generateMerkle generates a Merkle tree file for fd. If fd points to a file
// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash
// of the generated Merkle tree and the data size is returned. If fd points to
@@ -546,6 +600,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64,
params := &merkletree.GenerateParams{
TreeReader: &merkleReader,
TreeWriter: &merkleWriter,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
}
switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT {
@@ -611,7 +667,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// or directory other than the root, the parent Merkle tree file should
// have also been initialized.
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) {
- return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds")
+ return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds")
}
hash, dataSize, err := fd.generateMerkle(ctx)
@@ -657,6 +713,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui
// measureVerity returns the hash of fd, saved in verityDigest.
func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) {
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
var metadata linux.DigestMetadata
// If allowRuntimeEnable is true, an empty fd.d.hash indicates that
@@ -667,7 +726,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve
if fd.d.fs.allowRuntimeEnable {
return 0, syserror.ENODATA
}
- return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found")
+ return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found")
}
// The first part of VerityDigest is the metadata.
@@ -702,6 +761,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag
}
t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ return 0, syserror.EINVAL
+ }
_, err := primitive.CopyInt32Out(t, flags, f)
return 0, err
}
@@ -722,6 +784,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.
}
}
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Implement Read with PRead by setting offset.
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
// No need to verify if the file is not enabled yet in
@@ -742,7 +814,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// contains the expected xattrs. If the xattr does not exist, it
// indicates unexpected modifications to the file system.
if err == syserror.ENODATA {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
}
if err != nil {
return 0, err
@@ -752,7 +824,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
// unexpected modifications to the file system.
size, err := strconv.Atoi(dataSize)
if err != nil {
- return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
dataReader := vfs.FileReadWriteSeeker{
@@ -766,25 +838,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
}
n, err := merkletree.Verify(&merkletree.VerifyParams{
- Out: dst.Writer(ctx),
- File: &dataReader,
- Tree: &merkleReader,
- Size: int64(size),
- Name: fd.d.name,
- Mode: fd.d.mode,
- UID: fd.d.uid,
- GID: fd.d.gid,
+ Out: dst.Writer(ctx),
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ //TODO(b/156980949): Support passing other hash algorithms.
+ HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256,
ReadOffset: offset,
ReadSize: dst.NumBytes(),
Expected: fd.d.hash,
DataAndTreeInSameFile: false,
})
if err != nil {
- return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err))
+ return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
}
return n, err
}
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.EROFS
+}
+
// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block)
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index e301d35f5..c647cbfd3 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -25,10 +25,12 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -41,11 +43,18 @@ const maxDataSize = 100000
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
-func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) {
+func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("testutil.Boot: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+
rand.Seed(time.Now().UnixNano())
vfsObj := &vfs.VirtualFilesystem{}
if err := vfsObj.Init(ctx); err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
}
vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -67,16 +76,26 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v
},
})
if err != nil {
- return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err)
}
root := mntns.Root()
root.IncRef()
+
+ // Use lowerRoot in the task as we modify the lower file system
+ // directly in many tests.
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot)
+ if err != nil {
+ t.Fatalf("testutil.CreateTask: %v", err)
+ }
+
t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
})
- return vfsObj, root, nil
+ return vfsObj, root, task, nil
}
// newFileFD creates a new file in the verity mount, and returns the FD. The FD
@@ -145,8 +164,7 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
func TestOpen(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -180,11 +198,10 @@ func TestOpen(t *testing.T) {
}
}
-// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file
-// succeeds after enabling verity for it.
-func TestReadUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity
+// file succeeds after enabling verity for it.
+func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -213,11 +230,42 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
}
+// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity
+// file succeeds after enabling verity for it.
+func TestReadUnmodifiedFileSucceeds(t *testing.T) {
+ vfsObj, root, ctx, err := newVerityRoot(t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ buf := make([]byte, size)
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != int64(size) {
+ t.Errorf("fd.PRead got read length %d, want %d", n, size)
+ }
+}
+
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -248,10 +296,10 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
}
}
-// TestModifiedFileFails ensures that read from a modified verity file fails.
-func TestModifiedFileFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+// TestPReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestPReadModifiedFileFails(t *testing.T) {
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -289,15 +337,59 @@ func TestModifiedFileFails(t *testing.T) {
// Confirm that read from the modified file fails.
buf := make([]byte, size)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- t.Fatalf("fd.PRead succeeded with modified file")
+ t.Fatalf("fd.PRead succeeded, expected failure")
+ }
+}
+
+// TestReadModifiedFileFails ensures that read from a modified verity file
+// fails.
+func TestReadModifiedFileFails(t *testing.T) {
+ vfsObj, root, ctx, err := newVerityRoot(t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ // Open a new lowerFD that's read/writable.
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD
+
+ lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt: %v", err)
+ }
+
+ if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("corruptRandomBit: %v", err)
+ }
+
+ // Confirm that read from the modified file fails.
+ buf := make([]byte, size)
+ if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil {
+ t.Fatalf("fd.Read succeeded, expected failure")
}
}
// TestModifiedMerkleFails ensures that read from a verity file fails if the
// corresponding Merkle tree file is modified.
func TestModifiedMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -350,8 +442,7 @@ func TestModifiedMerkleFails(t *testing.T) {
// verity enabled directory fails if the hashes related to the target file in
// the parent Merkle tree file is modified.
func TestModifiedParentMerkleFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -428,8 +519,7 @@ func TestModifiedParentMerkleFails(t *testing.T) {
// TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file
// succeeds after enabling verity for it.
func TestUnmodifiedStatSucceeds(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -455,8 +545,7 @@ func TestUnmodifiedStatSucceeds(t *testing.T) {
// TestModifiedStatFails checks that getting stat for a file with modified stat
// should fail.
func TestModifiedStatFails(t *testing.T) {
- ctx := contexttest.Context(t)
- vfsObj, root, err := newVerityRoot(ctx, t)
+ vfsObj, root, ctx, err := newVerityRoot(t)
if err != nil {
t.Fatalf("newVerityRoot: %v", err)
}
@@ -489,3 +578,123 @@ func TestModifiedStatFails(t *testing.T) {
t.Errorf("fd.Stat succeeded when it should fail")
}
}
+
+// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed
+// verity enabled file or the corresponding Merkle tree file fails with the
+// verify error.
+func TestOpenDeletedFileFails(t *testing.T) {
+ testCases := []struct {
+ // Tests removing files is remove is true. Otherwise tests
+ // renaming files.
+ remove bool
+ // The original file is removed/renamed if changeFile is true.
+ changeFile bool
+ // The Merkle tree file is removed/renamed if changeMerkleFile
+ // is true.
+ changeMerkleFile bool
+ }{
+ {
+ remove: true,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: true,
+ changeFile: false,
+ changeMerkleFile: true,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ {
+ remove: false,
+ changeFile: true,
+ changeMerkleFile: false,
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) {
+ vfsObj, root, ctx, err := newVerityRoot(t)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-file"
+ fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newFileFD: %v", err)
+ }
+
+ // Enable verity on the file.
+ var args arch.SyscallArguments
+ args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
+ if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
+ t.Fatalf("Ioctl: %v", err)
+ }
+
+ rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD
+ if tc.remove {
+ if tc.changeFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ } else {
+ newFilename := "renamed-test-file"
+ if tc.changeFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("RenameAt: %v", err)
+ }
+ }
+ if tc.changeMerkleFile {
+ if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + filename),
+ }, &vfs.PathOperation{
+ Root: rootLowerVD,
+ Start: rootLowerVD,
+ Path: fspath.Parse(merklePrefix + newFilename),
+ }, &vfs.RenameOptions{}); err != nil {
+ t.Fatalf("UnlinkAt: %v", err)
+ }
+ }
+ }
+
+ // Ensure reopening the verity enabled file fails.
+ if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ Mode: linux.ModeRegular,
+ }); err != syserror.EIO {
+ t.Errorf("got OpenAt error: %v, expected EIO", err)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index fbe6d6aa6..f31277d30 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -32,9 +32,13 @@ type Stack interface {
InterfaceAddrs() map[int32][]InterfaceAddr
// AddInterfaceAddr adds an address to the network interface identified by
- // index.
+ // idx.
AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+ // RemoveInterfaceAddr removes an address from the network interface
+ // identified by idx.
+ RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error
+
// SupportsIPv6 returns true if the stack supports IPv6 connectivity.
SupportsIPv6() bool
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 1779cc6f3..9ebeba8a3 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -15,6 +15,9 @@
package inet
import (
+ "bytes"
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
return nil
}
+// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr.
+func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ interfaceAddrs, ok := s.InterfaceAddrsMap[idx]
+ if !ok {
+ return fmt.Errorf("unknown idx: %d", idx)
+ }
+
+ var filteredAddrs []InterfaceAddr
+ for _, interfaceAddr := range interfaceAddrs {
+ if !bytes.Equal(interfaceAddr.Addr, addr.Addr) {
+ filteredAddrs = append(filteredAddrs, addr)
+ }
+ }
+ s.InterfaceAddrsMap[idx] = filteredAddrs
+
+ return nil
+}
+
// SupportsIPv6 implements Stack.SupportsIPv6.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 9a24c6bdb..90dd4a047 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -79,7 +79,7 @@ go_template_instance(
out = "fd_table_refs.go",
package = "kernel",
prefix = "FDTable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FDTable",
},
@@ -90,7 +90,7 @@ go_template_instance(
out = "fs_context_refs.go",
package = "kernel",
prefix = "FSContext",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FSContext",
},
@@ -101,7 +101,7 @@ go_template_instance(
out = "ipc_namespace_refs.go",
package = "kernel",
prefix = "IPCNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "IPCNamespace",
},
@@ -112,7 +112,7 @@ go_template_instance(
out = "process_group_refs.go",
package = "kernel",
prefix = "ProcessGroup",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "ProcessGroup",
},
@@ -123,7 +123,7 @@ go_template_instance(
out = "session_refs.go",
package = "kernel",
prefix = "Session",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Session",
},
@@ -218,6 +218,7 @@ go_library(
"//pkg/amutex",
"//pkg/bits",
"//pkg/bpf",
+ "//pkg/cleanup",
"//pkg/context",
"//pkg/coverage",
"//pkg/cpuid",
@@ -228,7 +229,7 @@ go_library(
"//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/secio",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 1b9721534..0ddbe5ff6 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -19,7 +19,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs_vfs2"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -27,7 +27,7 @@ import (
// +stateify savable
type abstractEndpoint struct {
ep transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
name string
ns *AbstractSocketNamespace
}
@@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace {
// its backing socket.
type boundEndpoint struct {
transport.BoundEndpoint
- socket refs_vfs2.RefCounter
+ socket refsvfs2.RefCounter
}
// Release implements transport.BoundEndpoint.Release.
@@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
//
// When the last reference managed by socket is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
@@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran
// Remove removes the specified socket at name from the abstract socket
// namespace, if it has not yet been replaced.
-func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) {
+func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) {
a.mu.Lock()
defer a.mu.Unlock()
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 0ec7344cd..7aba31587 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
- f.init() // Initialize table.
+ f.initNoLeakCheck() // Initialize table.
f.used = 0
for fd, d := range m {
if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
@@ -240,6 +240,10 @@ func (f *FDTable) String() string {
case fileVFS2 != nil:
vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
+ vd := fileVFS2.VirtualDentry()
+ if vd.Dentry() == nil {
+ panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2))
+ }
name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
if err != nil {
fmt.Fprintf(&buf, "<err: %v>\n", err)
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index da79e6627..3476551f3 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -31,14 +31,21 @@ type descriptorTable struct {
slice unsafe.Pointer `state:".(map[int32]*descriptor)"`
}
-// init initializes the table.
+// initNoLeakCheck initializes the table without enabling leak checking.
//
-// TODO(gvisor.dev/1486): Enable leak check for FDTable.
-func (f *FDTable) init() {
+// This is used when loading an FDTable after S/R, during which the ref count
+// object itself will enable leak checking if necessary.
+func (f *FDTable) initNoLeakCheck() {
var slice []unsafe.Pointer // Empty slice.
atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
}
+// init initializes the table with leak checking.
+func (f *FDTable) init() {
+ f.initNoLeakCheck()
+ f.EnableLeakCheck()
+}
+
// get gets a file entry.
//
// The boolean indicates whether this was in range.
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index d46d1e1c1..41fb2a784 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext {
f.root.IncRef()
}
- return &FSContext{
+ ctx := &FSContext{
cwd: f.cwd,
root: f.root,
cwdVFS2: f.cwdVFS2,
rootVFS2: f.rootVFS2,
umask: f.umask,
}
+ ctx.EnableLeakCheck()
+ return ctx
}
// WorkingDirectory returns the current working directory.
@@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
+ if f.cwd != nil {
+ f.cwd.IncRef()
+ }
return f.cwd
}
// WorkingDirectoryVFS2 returns the current working directory.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwdVFS2.IncRef()
+ if f.cwdVFS2.Ok() {
+ f.cwdVFS2.IncRef()
+ }
return f.cwdVFS2
}
@@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
// RootDirectoryVFS2 returns the current filesystem root.
//
-// This will return nil if called after f is destroyed, otherwise it will return
-// a Dirent with a reference taken.
+// This will return an empty vfs.VirtualDentry if called after f is
+// destroyed, otherwise it will return a Dirent with a reference taken.
func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
f.mu.Lock()
defer f.mu.Unlock()
- f.rootVFS2.IncRef()
+ if f.rootVFS2.Ok() {
+ f.rootVFS2.IncRef()
+ }
return f.rootVFS2
}
diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go
index 3f34ee0db..b87e40dd1 100644
--- a/pkg/sentry/kernel/ipc_namespace.go
+++ b/pkg/sentry/kernel/ipc_namespace.go
@@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry {
return i.shms
}
-// DecRef implements refs_vfs2.RefCounter.DecRef.
+// DecRef implements refsvfs2.RefCounter.DecRef.
func (i *IPCNamespace) DecRef(ctx context.Context) {
i.IPCNamespaceRefs.DecRef(func() {
i.shms.Release(ctx)
diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go
index 060c056df..4fcdfc541 100644
--- a/pkg/sentry/kernel/kcov.go
+++ b/pkg/sentry/kernel/kcov.go
@@ -199,23 +199,25 @@ func (kcov *Kcov) DisableTrace(ctx context.Context) error {
}
kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
- kcov.mappable = nil
+ if kcov.mappable != nil {
+ kcov.mappable.DecRef(ctx)
+ kcov.mappable = nil
+ }
return nil
}
// Clear resets the mode and clears the owning task and memory mapping for kcov.
// It is called when the fd corresponding to kcov is closed. Note that the mode
// needs to be set so that the next call to kcov.TaskWork() will exit early.
-func (kcov *Kcov) Clear() {
+func (kcov *Kcov) Clear(ctx context.Context) {
kcov.mu.Lock()
- kcov.clearLocked()
- kcov.mu.Unlock()
-}
-
-func (kcov *Kcov) clearLocked() {
kcov.mode = linux.KCOV_MODE_INIT
kcov.owningTask = nil
- kcov.mappable = nil
+ if kcov.mappable != nil {
+ kcov.mappable.DecRef(ctx)
+ kcov.mappable = nil
+ }
+ kcov.mu.Unlock()
}
// OnTaskExit is called when the owning task exits. It is similar to
@@ -254,6 +256,7 @@ func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
// will look different under /proc/[pid]/maps than they do on Linux.
kcov.mappable = mm.NewSpecialMappable(fmt.Sprintf("[kcov:%d]", t.ThreadID()), kcov.mfp, fr)
}
+ kcov.mappable.IncRef()
opts.Mappable = kcov.mappable
opts.MappingIdentity = kcov.mappable
return nil
diff --git a/pkg/sentry/kernel/kcov_unsafe.go b/pkg/sentry/kernel/kcov_unsafe.go
index 6f64022eb..6f8a0266b 100644
--- a/pkg/sentry/kernel/kcov_unsafe.go
+++ b/pkg/sentry/kernel/kcov_unsafe.go
@@ -20,9 +20,9 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
)
-// countBlock provides a safemem.BlockSeq for k.count.
+// countBlock provides a safemem.BlockSeq for kcov.count.
//
// Like k.count, the block returned is protected by k.mu.
-func (k *Kcov) countBlock() safemem.BlockSeq {
- return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&k.count), int(unsafe.Sizeof(k.count))))
+func (kcov *Kcov) countBlock() safemem.BlockSeq {
+ return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&kcov.count), int(unsafe.Sizeof(kcov.count))))
}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 652cbb732..9b2be44d4 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -39,6 +39,7 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/eventchannel"
@@ -340,7 +341,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
return fmt.Errorf("Timekeeper is nil")
}
if args.Timekeeper.clocks == nil {
- return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()")
+ return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()")
}
if args.RootUserNamespace == nil {
return fmt.Errorf("RootUserNamespace is nil")
@@ -365,7 +366,7 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.useHostCores = true
maxCPU, err := hostcpu.MaxPossibleCPU()
if err != nil {
- return fmt.Errorf("Failed to get maximum CPU number: %v", err)
+ return fmt.Errorf("failed to get maximum CPU number: %v", err)
}
minAppCores := uint(maxCPU) + 1
if k.applicationCores < minAppCores {
@@ -429,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w wire.Writer) error {
+func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error {
saveStart := time.Now()
- ctx := k.SupervisorContext()
// Do not allow other Kernel methods to affect it while it's being saved.
k.extMu.Lock()
@@ -445,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
k.mf.StartEvictions()
k.mf.WaitForEvictions()
- // Flush write operations on open files so data reaches backing storage.
- // This must come after MemoryFile eviction since eviction may cause file
- // writes.
- if err := k.tasks.flushWritesToFiles(ctx); err != nil {
- return err
- }
+ if VFS2Enabled {
+ // Discard unsavable mappings, such as those for host file descriptors.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
- // Remove all epoll waiter objects from underlying wait queues.
- // NOTE: for programs to resume execution in future snapshot scenarios,
- // we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters(ctx)
+ // Prepare filesystems for saving. This must be done after
+ // invalidateUnsavableMappings(), since dropping memory mappings may
+ // affect filesystem state (e.g. page cache reference counts).
+ if err := k.vfs.PrepareSave(ctx); err != nil {
+ return err
+ }
+ } else {
+ // Flush cached file writes to backing storage. This must come after
+ // MemoryFile eviction since eviction may cause file writes.
+ if err := k.flushWritesToFiles(ctx); err != nil {
+ return err
+ }
- // Clear the dirent cache before saving because Dirents must be Loaded in a
- // particular order (parents before children), and Loading dirents from a cache
- // breaks that order.
- if err := k.flushMountSourceRefs(ctx); err != nil {
- return err
- }
+ // Remove all epoll waiter objects from underlying wait queues.
+ // NOTE: for programs to resume execution in future snapshot scenarios,
+ // we will need to re-establish these waiter objects after saving.
+ k.tasks.unregisterEpollWaiters(ctx)
- // Ensure that all inode and mount release operations have completed.
- fs.AsyncBarrier()
+ // Clear the dirent cache before saving because Dirents must be Loaded in a
+ // particular order (parents before children), and Loading dirents from a cache
+ // breaks that order.
+ if err := k.flushMountSourceRefs(ctx); err != nil {
+ return err
+ }
- // Once all fs work has completed (flushed references have all been released),
- // reset mount mappings. This allows individual mounts to save how inodes map
- // to filesystem resources. Without this, fs.Inodes cannot be restored.
- fs.SaveInodeMappings()
+ // Ensure that all inode and mount release operations have completed.
+ fs.AsyncBarrier()
- // Discard unsavable mappings, such as those for host file descriptors.
- // This must be done after waiting for "asynchronous fs work", which
- // includes async I/O that may touch application memory.
- if err := k.invalidateUnsavableMappings(ctx); err != nil {
- return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ // Once all fs work has completed (flushed references have all been released),
+ // reset mount mappings. This allows individual mounts to save how inodes map
+ // to filesystem resources. Without this, fs.Inodes cannot be restored.
+ fs.SaveInodeMappings()
+
+ // Discard unsavable mappings, such as those for host file descriptors.
+ // This must be done after waiting for "asynchronous fs work", which
+ // includes async I/O that may touch application memory.
+ //
+ // TODO(gvisor.dev/issue/1624): This rationale is believed to be
+ // obsolete since AIO callbacks are now waited-for by Kernel.Pause(),
+ // but this order is conservatively retained for VFS1.
+ if err := k.invalidateUnsavableMappings(ctx); err != nil {
+ return fmt.Errorf("failed to invalidate unsavable mappings: %v", err)
+ }
}
// Save the CPUID FeatureSet before the rest of the kernel so we can
@@ -485,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
+ if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- stats, err := state.Save(k.SupervisorContext(), w, k)
+ stats, err := state.Save(ctx, w, k)
if err != nil {
return err
}
@@ -501,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
+ if err := k.mf.SaveTo(ctx, w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -513,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
+//
+// Preconditions: !VFS2Enabled.
func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
- if VFS2Enabled {
- return nil // Not relevant.
- }
-
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -560,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi
return err
}
-func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return nil
- }
-
- return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
+// Preconditions: !VFS2Enabled.
+func (k *Kernel) flushWritesToFiles(ctx context.Context) error {
+ return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
}
@@ -588,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
})
}
-// Preconditions: The kernel must be paused.
-func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
- invalidated := make(map[*mm.MemoryManager]struct{})
- k.tasks.mu.RLock()
- defer k.tasks.mu.RUnlock()
- for t := range k.tasks.Root.tids {
- // We can skip locking Task.mu here since the kernel is paused.
- if mm := t.tc.MemoryManager; mm != nil {
- if _, ok := invalidated[mm]; !ok {
- if err := mm.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- invalidated[mm] = struct{}{}
- }
- }
- // I really wish we just had a sync.Map of all MMs...
- if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
- if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
+// Preconditions: !VFS2Enabled.
func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
- // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
- if VFS2Enabled {
- return
- }
-
ts.mu.RLock()
defer ts.mu.RUnlock()
@@ -643,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
}
}
+// Preconditions: The kernel must be paused.
+func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
+ invalidated := make(map[*mm.MemoryManager]struct{})
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+ for t := range k.tasks.Root.tids {
+ // We can skip locking Task.mu here since the kernel is paused.
+ if mm := t.tc.MemoryManager; mm != nil {
+ if _, ok := invalidated[mm]; !ok {
+ if err := mm.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ invalidated[mm] = struct{}{}
+ }
+ }
+ // I really wish we just had a sync.Map of all MMs...
+ if r, ok := t.runState.(*runSyscallAfterExecStop); ok {
+ if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
loadStart := time.Now()
initAppCores := k.applicationCores
@@ -655,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
+ if _, err := state.Load(ctx, r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -670,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the kernel state.
kernelStart := time.Now()
- stats, err := state.Load(k.SupervisorContext(), r, k)
+ stats, err := state.Load(ctx, r, k)
if err != nil {
return err
}
@@ -683,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
+ if err := k.mf.LoadFrom(ctx, r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -695,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock
net.Resume()
}
- // Ensure that all pending asynchronous work is complete:
- // - namedpipe opening
- // - inode file opening
- if err := fs.AsyncErrorBarrier(); err != nil {
- return err
+ if VFS2Enabled {
+ if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil {
+ return err
+ }
+ } else {
+ // Ensure that all pending asynchronous work is complete:
+ // - namedpipe opening
+ // - inode file opening
+ if err := fs.AsyncErrorBarrier(); err != nil {
+ return err
+ }
}
tcpip.AsyncLoading.Wait()
@@ -966,6 +979,10 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
}
tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
+ cu := cleanup.Make(func() {
+ tg.Release(ctx)
+ })
+ defer cu.Clean()
// Check which file to start from.
switch {
@@ -1025,13 +1042,14 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
MountNamespaceVFS2: mntnsVFS2,
ContainerID: args.ContainerID,
}
- t, err := k.tasks.NewTask(config)
+ t, err := k.tasks.NewTask(ctx, config)
if err != nil {
return nil, 0, err
}
t.traceExecEvent(tc) // Simulate exec for tracing.
// Success.
+ cu.Release()
tgid := k.tasks.Root.IDOfThreadGroup(tg)
if k.globalInit == nil {
k.globalInit = tg
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index f61039f5b..d96bf253b 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -33,6 +33,8 @@ import (
// VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should
// not be copied.
+//
+// +stateify savable
type VFSPipe struct {
// mu protects the fields below.
mu sync.Mutex `state:"nosave"`
@@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l
// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
// other FileDescriptions for splice(2) and tee(2).
+//
+// +stateify savable
type VFSPipeFD struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -237,8 +241,7 @@ func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// PipeSize implements fcntl(F_GETPIPE_SZ).
func (fd *VFSPipeFD) PipeSize() int64 {
- // Inline Pipe.FifoSize() rather than calling it with nil Context and
- // fs.File and ignoring the returned error (which is always nil).
+ // Inline Pipe.FifoSize() since we don't have a fs.File.
fd.pipe.mu.Lock()
defer fd.pipe.mu.Unlock()
return fd.pipe.max
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index c00fa1138..c39ecfb8f 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -283,6 +283,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File
return nil
}
+// GetStat extracts semid_ds information from the set.
+func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // "The calling process must have read permission on the semaphore set."
+ if !s.checkPerms(creds, fs.PermMask{Read: true}) {
+ return nil, syserror.EACCES
+ }
+
+ ds := &linux.SemidDS{
+ SemPerm: linux.IPCPerm{
+ Key: uint32(s.key),
+ UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)),
+ GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)),
+ CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)),
+ CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)),
+ Mode: uint16(s.perms.LinuxMode()),
+ Seq: 0, // IPC sequence not supported.
+ },
+ SemOTime: s.opTime.TimeT(),
+ SemCTime: s.changeTime.TimeT(),
+ SemNSems: uint64(s.Size()),
+ }
+ return ds, nil
+}
+
// SetVal overrides a semaphore value, waking up waiters as needed.
func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error {
if val < 0 || val > valueMax {
@@ -320,7 +347,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
}
for _, val := range vals {
- if val < 0 || val > valueMax {
+ if val > valueMax {
return syserror.ERANGE
}
}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index f8a382fd8..80a592c8f 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "shm_refs.go",
package = "shm",
prefix = "Shm",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Shm",
},
@@ -27,7 +27,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
- "//pkg/refs_vfs2",
+ "//pkg/refsvfs2",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 7a053f369..682080c14 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -206,6 +207,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
ipcns.IncRef()
}
+ cu := cleanup.Make(func() {
+ ipcns.DecRef(t)
+ })
+ defer cu.Clean()
netns := t.NetworkNamespace()
if opts.NewNetworkNamespace {
@@ -216,13 +221,18 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
mntnsVFS2 := t.mountNamespaceVFS2
if mntnsVFS2 != nil {
mntnsVFS2.IncRef()
+ cu.Add(func() {
+ mntnsVFS2.DecRef(t)
+ })
}
tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
if err != nil {
- ipcns.DecRef(t)
return 0, nil, err
}
+ cu.Add(func() {
+ tc.release()
+ })
// clone() returns 0 in the child.
tc.Arch.SetReturn(0)
if opts.Stack != 0 {
@@ -230,7 +240,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
}
if opts.SetTLS {
if !tc.Arch.SetTLS(uintptr(opts.TLS)) {
- ipcns.DecRef(t)
return 0, nil, syserror.EPERM
}
}
@@ -299,11 +308,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- nt, err := t.tg.pidns.owner.NewTask(cfg)
+ nt, err := t.tg.pidns.owner.NewTask(t, cfg)
+ // If NewTask succeeds, we transfer references to nt. If NewTask fails, it does
+ // the cleanup for us.
+ cu.Release()
if err != nil {
- if opts.NewThreadGroup {
- tg.release(t)
- }
return 0, nil, err
}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 239551eb6..ce7b9641d 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -286,7 +286,7 @@ func (*runExitMain) execute(t *Task) taskRunState {
// If this is the last task to exit from the thread group, release the
// thread group's resources.
if lastExiter {
- t.tg.release(t)
+ t.tg.Release(t)
}
// Detach tracees.
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index 6e2ff573a..8e28230cc 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -98,15 +99,18 @@ type TaskConfig struct {
// NewTask creates a new task defined by cfg.
//
// NewTask does not start the returned task; the caller must call Task.Start.
-func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
+//
+// If successful, NewTask transfers references held by cfg to the new task.
+// Otherwise, NewTask releases them.
+func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) {
t, err := ts.newTask(cfg)
if err != nil {
cfg.TaskContext.release()
- cfg.FSContext.DecRef(t)
- cfg.FDTable.DecRef(t)
- cfg.IPCNamespace.DecRef(t)
+ cfg.FSContext.DecRef(ctx)
+ cfg.FDTable.DecRef(ctx)
+ cfg.IPCNamespace.DecRef(ctx)
if cfg.MountNamespaceVFS2 != nil {
- cfg.MountNamespaceVFS2.DecRef(t)
+ cfg.MountNamespaceVFS2.DecRef(ctx)
}
return nil, err
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0b34c0099..a183b28c1 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -307,8 +308,8 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet {
return tg.limits
}
-// release releases the thread group's resources.
-func (tg *ThreadGroup) release(t *Task) {
+// Release releases the thread group's resources.
+func (tg *ThreadGroup) Release(ctx context.Context) {
// Timers must be destroyed without holding the TaskSet or signal mutexes
// since timers send signals with Timer.mu locked.
tg.itimerRealTimer.Destroy()
@@ -325,7 +326,7 @@ func (tg *ThreadGroup) release(t *Task) {
it.DestroyTimer()
}
if tg.mounts != nil {
- tg.mounts.DecRef(t)
+ tg.mounts.DecRef(ctx)
}
}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index d4610ec3b..98af2cc38 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -194,6 +194,10 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
log.Infof("Too many phdrs (%d): total size %d > %d", hdr.Phnum, totalPhdrSize, maxTotalPhdrSize)
return elfInfo{}, syserror.ENOEXEC
}
+ if int64(hdr.Phoff) < 0 || int64(hdr.Phoff+uint64(totalPhdrSize)) < 0 {
+ ctx.Infof("Unsupported phdr offset %d", hdr.Phoff)
+ return elfInfo{}, syserror.ENOEXEC
+ }
phdrBuf := make([]byte, totalPhdrSize)
_, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
@@ -437,6 +441,10 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in
ctx.Infof("PT_INTERP path too big: %v", phdr.Filesz)
return loadedELF{}, syserror.ENOEXEC
}
+ if int64(phdr.Off) < 0 || int64(phdr.Off+phdr.Filesz) < 0 {
+ ctx.Infof("Unsupported PT_INTERP offset %d", phdr.Off)
+ return loadedELF{}, syserror.ENOEXEC
+ }
path := make([]byte, phdr.Filesz)
_, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off))
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index b4a47ccca..6dbeccfe2 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -78,7 +78,7 @@ go_template_instance(
out = "aio_mappable_refs.go",
package = "mm",
prefix = "aioMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "aioMappable",
},
@@ -89,7 +89,7 @@ go_template_instance(
out = "special_mappable_refs.go",
package = "mm",
prefix = "SpecialMappable",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SpecialMappable",
},
@@ -127,6 +127,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safecopy",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 7a3311a70..5b09b9feb 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -83,6 +83,7 @@ go_library(
],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
"//pkg/context",
"//pkg/log",
"//pkg/memutil",
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 626d1eaa4..7c297fb9e 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -29,6 +29,7 @@ import (
"syscall"
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
@@ -224,6 +225,18 @@ type usageInfo struct {
refs uint64
}
+// canCommit returns true if the tracked region can be committed.
+func (u *usageInfo) canCommit() bool {
+ // refs must be greater than 0 because we assume that reclaimable pages
+ // (that aren't already known to be committed) are not committed. This
+ // isn't necessarily true, even after the reclaimer does Decommit(),
+ // because the kernel may subsequently back the hugepage-sized region
+ // containing the decommitted page with a hugepage. However, it's
+ // consistent with our treatment of unallocated pages, which have the same
+ // property.
+ return !u.knownCommitted && u.refs != 0
+}
+
// An EvictableMemoryUser represents a user of MemoryFile-allocated memory that
// may be asked to deallocate that memory in the presence of memory pressure.
type EvictableMemoryUser interface {
@@ -828,6 +841,11 @@ func (f *MemoryFile) UpdateUsage() error {
log.Debugf("UpdateUsage: skipped with usageSwapped!=0.")
return nil
}
+ // Linux updates usage values at CONFIG_HZ.
+ if scanningAfter := time.Now().Sub(f.usageLast).Milliseconds(); scanningAfter < time.Second.Milliseconds()/linux.CLOCKS_PER_SEC {
+ log.Debugf("UpdateUsage: skipped because previous scan happened %d ms back", scanningAfter)
+ return nil
+ }
f.usageLast = time.Now()
err = f.updateUsageLocked(currentUsage, mincore)
@@ -841,7 +859,7 @@ func (f *MemoryFile) UpdateUsage() error {
// pages by invoking checkCommitted, which is a function that, for each page i
// in bs, sets committed[i] to 1 if the page is committed and 0 otherwise.
//
-// Precondition: f.mu must be held.
+// Precondition: f.mu must be held; it may be unlocked and reacquired.
func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error {
// Track if anything changed to elide the merge. In the common case, we
// expect all segments to be committed and no merge to occur.
@@ -868,7 +886,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
} else if f.usageSwapped != 0 {
// We have more usage accounted for than the file itself.
// That's fine, we probably caught a race where pages were
- // being committed while the above loop was running. Just
+ // being committed while the below loop was running. Just
// report the higher number that we found and ignore swap.
usage.MemoryAccounting.Dec(f.usageSwapped, usage.System)
f.usageSwapped = 0
@@ -880,21 +898,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
// Iterate over all usage data. There will only be usage segments
// present when there is an associated reference.
- for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- val := seg.Value()
-
- // Already known to be committed; ignore.
- if val.knownCommitted {
- continue
- }
-
- // Assume that reclaimable pages (that aren't already known to be
- // committed) are not committed. This isn't necessarily true, even
- // after the reclaimer does Decommit(), because the kernel may
- // subsequently back the hugepage-sized region containing the
- // decommitted page with a hugepage. However, it's consistent with our
- // treatment of unallocated pages, which have the same property.
- if val.refs == 0 {
+ for seg := f.usage.FirstSegment(); seg.Ok(); {
+ if !seg.ValuePtr().canCommit() {
+ seg = seg.NextSegment()
continue
}
@@ -917,56 +923,53 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
}
// Query for new pages in core.
- if err := checkCommitted(s, buf); err != nil {
+ // NOTE(b/165896008): mincore (which is passed as checkCommitted)
+ // by f.UpdateUsage() might take a really long time. So unlock f.mu
+ // while checkCommitted runs.
+ f.mu.Unlock()
+ err := checkCommitted(s, buf)
+ f.mu.Lock()
+ if err != nil {
checkErr = err
return
}
// Scan each page and switch out segments.
- populatedRun := false
- populatedRunStart := 0
- for i := 0; i <= bufLen; i++ {
- // We run past the end of the slice here to
- // simplify the logic and only set populated if
- // we're still looking at elements.
- populated := false
- if i < bufLen {
- populated = buf[i]&0x1 != 0
- }
-
- switch {
- case populated == populatedRun:
- // Keep the run going.
- continue
- case populated && !populatedRun:
- // Begin the run.
- populatedRun = true
- populatedRunStart = i
- // Keep going.
+ seg := f.usage.LowerBoundSegment(r.Start)
+ for i := 0; i < bufLen; {
+ if buf[i]&0x1 == 0 {
+ i++
continue
- case !populated && populatedRun:
- // Finish the run by changing this segment.
- runRange := memmap.FileRange{
- Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
- End: r.Start + uint64(i*usermem.PageSize),
+ }
+ // Scan to the end of this committed range.
+ j := i + 1
+ for ; j < bufLen; j++ {
+ if buf[j]&0x1 == 0 {
+ break
}
- seg = f.usage.Isolate(seg, runRange)
- seg.ValuePtr().knownCommitted = true
- // Advance the segment only if we still
- // have work to do in the context of
- // the original segment from the for
- // loop. Otherwise, the for loop itself
- // will advance the segment
- // appropriately.
- if runRange.End != r.End {
- seg = seg.NextSegment()
+ }
+ committedFR := memmap.FileRange{
+ Start: r.Start + uint64(i*usermem.PageSize),
+ End: r.Start + uint64(j*usermem.PageSize),
+ }
+ // Advance seg to committedFR.Start.
+ for seg.Ok() && seg.End() < committedFR.Start {
+ seg = seg.NextSegment()
+ }
+ // Mark pages overlapping committedFR as committed.
+ for seg.Ok() && seg.Start() < committedFR.End {
+ if seg.ValuePtr().canCommit() {
+ seg = f.usage.Isolate(seg, committedFR)
+ seg.ValuePtr().knownCommitted = true
+ amount := seg.Range().Length()
+ usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind)
+ f.usageExpected += amount
+ changedAny = true
}
- amount := runRange.Length()
- usage.MemoryAccounting.Inc(amount, val.kind)
- f.usageExpected += amount
- changedAny = true
- populatedRun = false
+ seg = seg.NextSegment()
}
+ // Continue scanning for committed pages.
+ i = j + 1
}
// Advance r.Start.
@@ -978,6 +981,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
if err != nil {
return err
}
+
+ // Continue with the first segment after r.End.
+ seg = f.usage.LowerBoundSegment(r.End)
}
return nil
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index ed5ae03d3..58f3d6fdd 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -39,6 +39,16 @@ var (
}
)
+// getTLS returns the value of TPIDR_EL0 register.
+//
+//go:nosplit
+func getTLS() (value uint64)
+
+// setTLS writes the TPIDR_EL0 value.
+//
+//go:nosplit
+func setTLS(value uint64)
+
// bluepillArchEnter is called during bluepillEnter.
//
//go:nosplit
@@ -51,6 +61,8 @@ func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) {
regs.Pstate = context.Pstate
regs.Pstate &^= uint64(ring0.PsrFlagsClear)
regs.Pstate |= ring0.KernelFlagsSet
+ regs.TPIDR_EL0 = getTLS()
+
return
}
@@ -65,6 +77,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
context.Pstate = regs.Pstate
context.Pstate &^= uint64(ring0.PsrFlagsClear)
context.Pstate |= ring0.UserFlagsSet
+ setTLS(regs.TPIDR_EL0)
lazyVfp := c.GetLazyVFP()
if lazyVfp != 0 {
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 04efa0147..09c7e88e5 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -32,6 +32,18 @@
#define CONTEXT_PC 0x1B8
#define CONTEXT_R0 0xB8
+// getTLS returns the value of TPIDR_EL0 register.
+TEXT ·getTLS(SB),NOSPLIT,$0-8
+ MRS TPIDR_EL0, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+// setTLS writes the TPIDR_EL0 value.
+TEXT ·setTLS(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ MSR R1, TPIDR_EL0
+ RET
+
// See bluepill.go.
TEXT ·bluepill(SB),NOSPLIT,$0
begin:
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
index 84df0f878..5831b9345 100644
--- a/pkg/sentry/platform/kvm/kvm_const_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -38,6 +38,8 @@ const (
_KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
_KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
_KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+ _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a
+ _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00
)
// Arm64: Architectural Feature Access Control Register EL1.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 61ed24d01..f70d761fd 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ ktime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -625,3 +626,35 @@ func (c *vCPU) BounceToKernel() {
func (c *vCPU) BounceToHost() {
c.bounce(true)
}
+
+// setSystemTimeLegacy calibrates and sets an approximate system time.
+func (c *vCPU) setSystemTimeLegacy() error {
+ const minIterations = 10
+ minimum := uint64(0)
+ for iter := 0; ; iter++ {
+ // Try to set the TSC to an estimate of where it will be
+ // on the host during a "fast" system call iteration.
+ start := uint64(ktime.Rdtsc())
+ if err := c.setTSC(start + (minimum / 2)); err != nil {
+ return err
+ }
+ // See if this is our new minimum call time. Note that this
+ // serves two functions: one, we make sure that we are
+ // accurately predicting the offset we need to set. Second, we
+ // don't want to do the final set on a slow call, which could
+ // produce a really bad result.
+ end := uint64(ktime.Rdtsc())
+ if end < start {
+ continue // Totally bogus: unstable TSC?
+ }
+ current := end - start
+ if current < minimum || iter == 0 {
+ minimum = current // Set our new minimum.
+ }
+ // Is this past minIterations and within ~10% of minimum?
+ upperThreshold := (((minimum << 3) + minimum) >> 3)
+ if iter >= minIterations && current <= upperThreshold {
+ return nil
+ }
+ }
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c67127d95..a8b729e62 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error {
}
}
-// setSystemTimeLegacy calibrates and sets an approximate system time.
-func (c *vCPU) setSystemTimeLegacy() error {
- const minIterations = 10
- minimum := uint64(0)
- for iter := 0; ; iter++ {
- // Try to set the TSC to an estimate of where it will be
- // on the host during a "fast" system call iteration.
- start := uint64(ktime.Rdtsc())
- if err := c.setTSC(start + (minimum / 2)); err != nil {
- return err
- }
- // See if this is our new minimum call time. Note that this
- // serves two functions: one, we make sure that we are
- // accurately predicting the offset we need to set. Second, we
- // don't want to do the final set on a slow call, which could
- // produce a really bad result.
- end := uint64(ktime.Rdtsc())
- if end < start {
- continue // Totally bogus: unstable TSC?
- }
- current := end - start
- if current < minimum || iter == 0 {
- minimum = current // Set our new minimum.
- }
- // Is this past minIterations and within ~10% of minimum?
- upperThreshold := (((minimum << 3) + minimum) >> 3)
- if iter >= minIterations && current <= upperThreshold {
- return nil
- }
- }
-}
-
// nonCanonical generates a canonical address return.
//
//go:nosplit
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index a163f956d..1344ed3c9 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error {
}
c.floatingPointState = arch.NewFloatingPointData()
+
+ return c.setSystemTime()
+}
+
+// setTSC sets the counter Virtual Offset.
+func (c *vCPU) setTSC(value uint64) error {
+ var (
+ reg kvmOneReg
+ data uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ reg.id = _KVM_ARM64_REGS_TIMER_CNT
+ data = uint64(value)
+
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
return nil
}
+// setSystemTime sets the vCPU to the system time.
+func (c *vCPU) setSystemTime() error {
+ return c.setSystemTimeLegacy()
+}
+
//go:nosplit
func (c *vCPU) loadSegments(tid uint64) {
// TODO(gvisor.dev/issue/1238): TLS is not supported.
@@ -235,8 +259,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return c.fault(int32(syscall.SIGSEGV), info)
case ring0.Vector(bounce): // ring0.VirtualizationException
return usermem.NoAccess, platform.ErrContextInterrupt
- case ring0.El0Sync_undef,
- ring0.El1Sync_undef:
+ case ring0.El0Sync_undef:
+ return c.fault(int32(syscall.SIGILL), info)
+ case ring0.El1Sync_undef:
*info = arch.SignalInfo{
Signo: int32(syscall.SIGILL),
Code: 1, // ILL_ILLOPC (illegal opcode).
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 2370a9276..1079a024b 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -366,6 +366,19 @@
MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// EXCEPTION_WITH_ERROR is a common exception handler function.
+#define EXCEPTION_WITH_ERROR(user, vector) \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ WORD $0xd538601a; \ //MRS FAR_EL1, R26
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG); \
+ MOVD $user, R3; \
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user.
+ MOVD $vector, R3; \
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG); \
+ MRS ESR_EL1, R3; \
+ MOVD R3, CPU_ERROR_CODE(RSV_REG); \
+ B ·kernelExitToEl1(SB);
+
// storeAppASID writes the application's asid value.
TEXT ·storeAppASID(SB),NOSPLIT,$0-8
MOVD asid+0(FP), R1
@@ -659,21 +672,7 @@ el0_svc:
el0_da:
el0_ia:
- WORD $0xd538d092 //MRS TPIDR_EL1, R18
- WORD $0xd538601a //MRS FAR_EL1, R26
-
- MOVD R26, CPU_FAULT_ADDR(RSV_REG)
-
- MOVD $1, R3
- MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
-
- MOVD $PageFault, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- MRS ESR_EL1, R3
- MOVD R3, CPU_ERROR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, PageFault)
el0_fpsimd_acc:
B ·Shutdown(SB)
@@ -688,10 +687,7 @@ el0_sp_pc:
B ·Shutdown(SB)
el0_undef:
- MOVD $El0Sync_undef, R3
- MOVD R3, CPU_VECTOR_CODE(RSV_REG)
-
- B ·kernelExitToEl1(SB)
+ EXCEPTION_WITH_ERROR(1, El0Sync_undef)
el0_dbg:
B ·Shutdown(SB)
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 2f1abcb0f..d91a09de1 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -53,12 +53,6 @@ func LoadFloatingPoint(*byte)
// SaveFloatingPoint saves floating point state.
func SaveFloatingPoint(*byte)
-// GetTLS returns the value of TPIDR_EL0 register.
-func GetTLS() (value uint64)
-
-// SetTLS writes the TPIDR_EL0 value.
-func SetTLS(value uint64)
-
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 8aabf7d0e..da9d3cf55 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -29,16 +29,6 @@ TEXT ·FlushTlbAll(SB),NOSPLIT,$0
ISB $15
RET
-TEXT ·GetTLS(SB),NOSPLIT,$0-8
- MRS TPIDR_EL0, R1
- MOVD R1, ret+0(FP)
- RET
-
-TEXT ·SetTLS(SB),NOSPLIT,$0-8
- MOVD addr+0(FP), R1
- MSR R1, TPIDR_EL0
- RET
-
TEXT ·CPACREL1(SB),NOSPLIT,$0-8
WORD $0xd5381041 // MRS CPACR_EL1, R1
MOVD R1, ret+0(FP)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
index 1a49f12a2..5ddd10256 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -36,7 +36,7 @@ const (
pudSize = 1 << pudShift
pgdSize = 1 << pgdShift
- ttbrASIDOffset = 55
+ ttbrASIDOffset = 48
ttbrASIDMask = 0xff
entriesPerPage = 512
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index d9621968c..37d02948f 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -24,6 +24,8 @@ import (
)
// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+//
+// +stateify savable
type SCMRightsVFS2 interface {
transport.RightsControlMessage
@@ -34,9 +36,11 @@ type SCMRightsVFS2 interface {
Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
}
-// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
-// maintained for each vfs.FileDescription and is release either when an FD is created or
-// when the Release method is called.
+// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference
+// is maintained for each vfs.FileDescription and is release either when an FD
+// is created or when the Release method is called.
+//
+// +stateify savable
type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 163af329b..9a2cac40b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// +stateify savable
type socketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{})
func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &socketVFS2{
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index faa61160e..7e7857ac3 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return syserror.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+func (s *Stack) SetTCPSACKEnabled(bool) error {
return syserror.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return syserror.EACCES
}
@@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
if rawLine == "" {
- return fmt.Errorf("Failed to get raw line")
+ return fmt.Errorf("failed to get raw line")
}
parts := strings.SplitN(rawLine, ":", 2)
if len(parts) != 2 {
- return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ return fmt.Errorf("failed to get prefix from: %q", rawLine)
}
sliceStat = toSlice(stat)
fields := strings.Fields(strings.TrimSpace(parts[1]))
if len(fields) != len(sliceStat) {
- return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ return fmt.Errorf("failed to parse fields: %q", rawLine)
}
if _, ok := stat.(*inet.StatSNMPTCP); ok {
snmpTCP = true
@@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
}
if err != nil {
- return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err)
}
}
@@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
+func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 844acfede..352c51390 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.TCPProtocolNumber {
- return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber)
}
return &TCPMatcher{
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 63201201c..c88d8268d 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma
}
if filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber)
}
return &UDPMatcher{
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
index e8930f031..f061c5d62 100644
--- a/pkg/sentry/socket/netlink/provider_vfs2.go
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol
vfsfd := &s.vfsfd
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
DenyPRead: true,
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index c84d8bd7c..22216158e 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -36,9 +36,9 @@ type commandKind int
const (
kindNew commandKind = 0x0
- kindDel = 0x1
- kindGet = 0x2
- kindSet = 0x3
+ kindDel commandKind = 0x1
+ kindGet commandKind = 0x2
+ kindSet commandKind = 0x3
)
func typeKind(typ uint16) commandKind {
@@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
}
attrs = rest
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We add the local interface address here
+ // and ignore the IFA_ADDRESS.
switch ahdr.Type {
case linux.IFA_LOCAL:
err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
@@ -439,8 +444,57 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin
} else if err != nil {
return syserr.ErrInvalidArgument
}
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+ return nil
+}
+
+// delAddr handles RTM_DELADDR requests.
+func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ // NOTE: A netlink message will contain multiple header attributes.
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent
+ // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the
+ // local interface address. We use the local interface address to
+ // remove the address and ignore the IFA_ADDRESS.
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err != nil {
+ return syserr.ErrInvalidArgument
+ }
+ case linux.IFA_ADDRESS:
+ default:
+ return syserr.ErrNotSupported
}
}
+
return nil
}
@@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
return p.newAddr(ctx, msg, ms)
+ case linux.RTM_DELADDR:
+ return p.delAddr(ctx, msg, ms)
default:
return syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 5ddcd4be5..3baad098b 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -16,6 +16,7 @@
package netlink
import (
+ "io"
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -748,6 +749,12 @@ func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, t
buf := make([]byte, src.NumBytes())
n, err := src.CopyIn(ctx, buf)
+ // io.EOF can be only returned if src is a file, this means that
+ // sendMsg is called from splice and the error has to be ignored in
+ // this case.
+ if err == io.EOF {
+ err = nil
+ }
if err != nil {
// Don't partially consume messages.
return 0, syserr.FromError(err)
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index c83b23242..461d524e5 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -37,6 +37,8 @@ import (
// to/from the kernel.
//
// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 87e30d742..86c634715 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -587,6 +587,11 @@ func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
+ // EOF can be returned only if src is a file and this means it
+ // is in a splice syscall and the error has to be ignored.
+ if err == io.EOF {
+ return v, nil
+ }
return nil, tcpip.ErrBadAddress
}
return v, nil
@@ -1239,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
+ case linux.SO_ACCEPTCONN:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 4c6791fff..b0d9e4d9e 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -35,6 +35,8 @@ import (
// SocketVFS2 encapsulates all the state needed to represent a network stack
// endpoint in the kernel context.
+//
+// +stateify savable
type SocketVFS2 struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
@@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
}
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
s := &SocketVFS2{
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 1028d2a6e..fa9ac9059 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
-// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+// convertAddr converts an InterfaceAddr to a ProtocolAddress.
+func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) {
var (
- protocol tcpip.NetworkProtocolNumber
- address tcpip.Address
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ protocolAddress tcpip.ProtocolAddress
)
switch addr.Family {
case linux.AF_INET:
- if len(addr.Addr) < header.IPv4AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv4AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv4.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
-
+ address = tcpip.Address(addr.Addr)
case linux.AF_INET6:
- if len(addr.Addr) < header.IPv6AddressSize {
- return syserror.EINVAL
+ if len(addr.Addr) != header.IPv6AddressSize {
+ return protocolAddress, syserror.EINVAL
}
if addr.PrefixLen > header.IPv6AddressSize*8 {
- return syserror.EINVAL
+ return protocolAddress, syserror.EINVAL
}
protocol = ipv6.ProtocolNumber
- address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
-
+ address = tcpip.Address(addr.Addr)
default:
- return syserror.ENOTSUP
+ return protocolAddress, syserror.ENOTSUP
}
- protocolAddress := tcpip.ProtocolAddress{
+ protocolAddress = tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: int(addr.PrefixLen),
},
}
+ return protocolAddress, nil
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
// Attach address to interface.
- if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network if it doesn't exist already.
+ localRoute := tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: nicID,
+ }
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ if rt.Equal(localRoute) {
+ return nil
+ }
+ }
+
+ // Local route does not exist yet. Add it.
+ s.Stack.AddRoute(localRoute)
+
+ return nil
+}
+
+// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
+func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ protocolAddress, err := convertAddr(addr)
+ if err != nil {
+ return err
+ }
+
+ // Remove addresses matching the address and prefix.
+ nicID := tcpip.NICID(idx)
+ if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil {
return syserr.TranslateNetstackError(err).ToError()
}
- // Add route for local network.
- s.Stack.AddRoute(tcpip.Route{
+ // Remove the corresponding local network route if it exists.
+ localRoute := tcpip.Route{
Destination: protocolAddress.AddressWithPrefix.Subnet(),
Gateway: "", // No gateway for local network.
- NIC: tcpip.NICID(idx),
+ NIC: nicID,
+ }
+ s.Stack.RemoveRoutes(func(rt tcpip.Route) bool {
+ return rt.Equal(localRoute)
})
+
return nil
}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cc7408698..cce0acc33 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "socket_refs.go",
package = "unix",
prefix = "socketOperations",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketOperations",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "socket_vfs2_refs.go",
package = "unix",
prefix = "socketVFS2",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketVFS2",
},
@@ -43,6 +43,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 26c3a51b9..3ebbd28b0 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "queue_refs.go",
package = "transport",
prefix = "queue",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "queue",
},
@@ -44,6 +44,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d6fc03520..b648273a4 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -32,6 +32,8 @@ import (
const initialLimit = 16 * 1024
// A RightsControlMessage is a control message containing FDs.
+//
+// +stateify savable
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
Clone() RightsControlMessage
@@ -336,7 +338,7 @@ type Receiver interface {
RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
- // called before droping all references to a Receiver.
+ // called before dropping all references to a Receiver.
Release(ctx context.Context)
}
@@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
c := q.control.Clone()
// Don't consume data since we are peeking.
- copied, data, _ = vecCopy(data, q.buffer)
+ copied, _, _ = vecCopy(data, q.buffer)
return copied, copied, c, false, q.addr, notify, nil
}
@@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
return copied, copied, c, cmTruncated, q.addr, notify, nil
}
+// Release implements Receiver.Release.
+func (q *streamQueueReceiver) Release(ctx context.Context) {
+ q.queueReceiver.Release(ctx)
+ q.control.Release(ctx)
+}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
// Passcred implements Endpoint.Passcred.
@@ -619,7 +627,7 @@ type ConnectedEndpoint interface {
SendMaxQueueSize() int64
// Release releases any resources owned by the ConnectedEndpoint. It should
- // be called before droping all references to a ConnectedEndpoint.
+ // be called before dropping all references to a ConnectedEndpoint.
Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
@@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.PasscredOption:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index a4a76d0a3..adad485a9 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
},
}
s.EnableLeakCheck()
-
return fs.NewFile(ctx, d, flags, &s)
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 678355fb9..7a78444dc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// returns a corresponding file description.
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
@@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.EnableLeakCheck()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index 0ea4aab8b..563d60578 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -12,10 +12,12 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/time",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/state/statefile",
"//pkg/syserror",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 245d2c5cf..167754537 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -19,10 +19,12 @@ import (
"fmt"
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/state/statefile"
"gvisor.dev/gvisor/pkg/syserror"
@@ -57,7 +59,7 @@ type SaveOpts struct {
}
// Save saves the system state.
-func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
+func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error {
log.Infof("Sandbox save started, pausing all tasks.")
k.Pause()
k.ReceiveTaskStates()
@@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
err = ErrStateFile{err}
} else {
// Save the kernel.
- err = k.SaveTo(wc)
+ err = k.SaveTo(ctx, wc)
// ENOSPC is a state file error. This error can only come from
// writing the state file, and not from fs.FileOperations.Fsync
@@ -108,7 +110,7 @@ type LoadOpts struct {
}
// Load loads the given kernel, setting the provided platform and stack.
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error {
+func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error {
// Open the file.
r, m, err := statefile.NewReader(opts.Source, opts.Key)
if err != nil {
@@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er
previousMetadata = m
// Restore the Kernel object graph.
- return k.LoadFrom(r, n, clocks)
+ return k.LoadFrom(ctx, r, n, clocks, vfsOpts)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 9c9def7cd..36902d177 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index 47dadb800..c2d4bf805 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -129,9 +129,17 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getPID(t, id, num)
return uintptr(v), nil, err
+ case linux.IPC_STAT:
+ arg := args[3].Pointer()
+ ds, err := ipcStat(t, id)
+ if err == nil {
+ _, err = ds.CopyOut(t, arg)
+ }
+
+ return 0, nil, err
+
case linux.IPC_INFO,
linux.SEM_INFO,
- linux.IPC_STAT,
linux.SEM_STAT,
linux.SEM_STAT_ANY,
linux.GETNCNT,
@@ -171,6 +179,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP
return set.Change(t, creds, owner, perms)
}
+func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByID(id)
+ if set == nil {
+ return nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ return set.GetStat(creds)
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index 6320593f0..db3d924d9 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -21,7 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
)
-// Sysinfo implements the sysinfo syscall as described in man 2 sysinfo.
+// Sysinfo implements Linux syscall sysinfo(2).
func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index d8b8d9783..36e89700e 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -145,16 +145,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(file.StatusFlags()), nil, nil
case linux.F_SETFL:
return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
- case linux.F_SETPIPE_SZ:
- pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
- if !ok {
- return 0, nil, syserror.EBADF
- }
- n, err := pipefile.SetPipeSize(int64(args[2].Int()))
- if err != nil {
- return 0, nil, err
- }
- return uintptr(n), nil, nil
case linux.F_GETOWN:
owner, hasOwner := getAsyncOwner(t, file)
if !hasOwner {
@@ -190,6 +180,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ case linux.F_SETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ n, err := pipefile.SetPipeSize(int64(args[2].Int()))
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
case linux.F_GETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
index bf5c1171f..035e2a6b0 100644
--- a/pkg/sentry/syscalls/linux/vfs2/splice.go
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -45,6 +45,9 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if count > int64(kernel.MAX_RW_COUNT) {
count = int64(kernel.MAX_RW_COUNT)
}
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
@@ -192,6 +195,9 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if count > int64(kernel.MAX_RW_COUNT) {
count = int64(kernel.MAX_RW_COUNT)
}
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
// Check for invalid flags.
if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index ab1d140d2..5ed6726ab 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -278,7 +278,7 @@ func TotalMemory(memSize, used uint64) uint64 {
}
if memSize < used {
memSize = used
- // Bump totalSize to the next largest power of 2, if one exists, so
+ // Bump memSize to the next largest power of 2, if one exists, so
// that MemFree isn't 0.
if msb := bits.MostSignificantOne64(memSize); msb < 63 {
memSize = uint64(1) << (uint(msb) + 1)
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index c855608db..440c9307c 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -32,7 +32,7 @@ go_template_instance(
out = "file_description_refs.go",
package = "vfs",
prefix = "FileDescription",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "FileDescription",
},
@@ -43,7 +43,7 @@ go_template_instance(
out = "mount_namespace_refs.go",
package = "vfs",
prefix = "MountNamespace",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "MountNamespace",
},
@@ -54,7 +54,7 @@ go_template_instance(
out = "filesystem_refs.go",
package = "vfs",
prefix = "Filesystem",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "Filesystem",
},
@@ -87,6 +87,7 @@ go_library(
"pathname.go",
"permissions.go",
"resolving_path.go",
+ "save_restore.go",
"vfs.go",
],
visibility = ["//pkg/sentry:internal"],
@@ -99,6 +100,7 @@ go_library(
"//pkg/gohacks",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 8f36c3e3b..a98aac52b 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -74,7 +74,7 @@ type epollInterestKey struct {
// +stateify savable
type epollInterest struct {
// epoll is the owning EpollInstance. epoll is immutable.
- epoll *EpollInstance
+ epoll *EpollInstance `state:"wait"`
// key is the file to which this epollInterest applies. key is immutable.
key epollInterestKey
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 183957ad8..546e445aa 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
index 2d27d9d35..ba6e6ed49 100644
--- a/pkg/sentry/vfs/genericfstree/genericfstree.go
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
return vfs.PrependPathAtVFSRootError{}
}
- if &d.vfsd == mnt.Root() {
+ if mnt != nil && &d.vfsd == mnt.Root() {
return nil
}
if d.parent == nil {
@@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath
d = d.parent
}
}
+
+// DebugPathname returns a pathname to d relative to its filesystem root.
+// DebugPathname does not correspond to any Linux function; it's used to
+// generate dentry pathnames for debugging.
+func DebugPathname(d *Dentry) string {
+ var b fspath.Builder
+ _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ return b.String()
+}
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
index 55783d4eb..1ff202f2a 100644
--- a/pkg/sentry/vfs/lock.go
+++ b/pkg/sentry/vfs/lock.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package lock provides POSIX and BSD style file locking for VFS2 file
-// implementations.
-//
-// The actual implementations can be found in the lock package under
-// sentry/fs/lock.
package vfs
import (
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 78f115bfa..d452d2cda 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -106,6 +107,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount
if opts.ReadOnly {
mnt.setReadOnlyLocked(true)
}
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Register(mnt, "vfs.Mount")
+ }
return mnt
}
@@ -489,26 +493,38 @@ func (mnt *Mount) IncRef() {
// DecRef decrements mnt's reference count.
func (mnt *Mount) DecRef(ctx context.Context) {
- refs := atomic.AddInt64(&mnt.refs, -1)
- if refs&^math.MinInt64 == 0 { // mask out MSB
- var vd VirtualDentry
- if mnt.parent() != nil {
- mnt.vfs.mountMu.Lock()
- mnt.vfs.mounts.seq.BeginWrite()
- vd = mnt.vfs.disconnectLocked(mnt)
- mnt.vfs.mounts.seq.EndWrite()
- mnt.vfs.mountMu.Unlock()
- }
- if mnt.root != nil {
- mnt.root.DecRef(ctx)
- }
- mnt.fs.DecRef(ctx)
- if vd.Ok() {
- vd.DecRef(ctx)
+ r := atomic.AddInt64(&mnt.refs, -1)
+ if r&^math.MinInt64 == 0 { // mask out MSB
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.Unregister(mnt, "vfs.Mount")
}
+ mnt.destroy(ctx)
}
}
+func (mnt *Mount) destroy(ctx context.Context) {
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ if mnt.root != nil {
+ mnt.root.DecRef(ctx)
+ }
+ mnt.fs.DecRef(ctx)
+ if vd.Ok() {
+ vd.DecRef(ctx)
+ }
+}
+
+// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
+func (mnt *Mount) LeakMessage() string {
+ return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs))
+}
+
// DecRef decrements mntns' reference count.
func (mntns *MountNamespace) DecRef(ctx context.Context) {
vfs := mntns.root.fs.VirtualFilesystem()
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index b7d122d22..cb48c37a1 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -98,7 +98,6 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- // FIXME(gvisor.dev/issue/1663): Slots need to be saved.
slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
@@ -212,6 +211,26 @@ loop:
}
}
+// Range calls f on each Mount in mt. If f returns false, Range stops iteration
+// and returns immediately.
+func (mt *mountTable) Range(f func(*Mount) bool) {
+ tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
+ slotPtr := mt.slots
+ last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes))
+ for {
+ slot := (*mountSlot)(slotPtr)
+ if slot.value != nil {
+ if !f((*Mount)(slot.value)) {
+ return
+ }
+ }
+ if slotPtr == last {
+ return
+ }
+ slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes)
+ }
+}
+
// Insert inserts the given mount into mt.
//
// Preconditions: mt must not already contain a Mount with the same mount point
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
new file mode 100644
index 000000000..46e50d55d
--- /dev/null
+++ b/pkg/sentry/vfs/save_restore.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
+)
+
+// FilesystemImplSaveRestoreExtension is an optional extension to
+// FilesystemImpl.
+type FilesystemImplSaveRestoreExtension interface {
+ // PrepareSave prepares this filesystem for serialization.
+ PrepareSave(ctx context.Context) error
+
+ // CompleteRestore completes restoration from checkpoint for this
+ // filesystem after deserialization.
+ CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error
+}
+
+// PrepareSave prepares all filesystems for serialization.
+func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.PrepareSave(ctx); err != nil {
+ ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to prepare for serialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestore completes restoration from checkpoint for all filesystems
+// after deserialization.
+func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error {
+ failures := 0
+ for fs := range vfs.getFilesystems() {
+ if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok {
+ if err := ext.CompleteRestore(ctx, *opts); err != nil {
+ ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err)
+ failures++
+ }
+ }
+ fs.DecRef(ctx)
+ }
+ if failures != 0 {
+ return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures)
+ }
+ return nil
+}
+
+// CompleteRestoreOptions contains options to
+// VirtualFilesystem.CompleteRestore() and
+// FilesystemImplSaveRestoreExtension.CompleteRestore().
+type CompleteRestoreOptions struct {
+ // If ValidateFileSizes is true, filesystem implementations backed by
+ // remote filesystems should verify that file sizes have not changed
+ // between checkpoint and restore.
+ ValidateFileSizes bool
+
+ // If ValidateFileModificationTimestamps is true, filesystem
+ // implementations backed by remote filesystems should validate that file
+ // mtimes have not changed between checkpoint and restore.
+ ValidateFileModificationTimestamps bool
+}
+
+// saveMounts is called by stateify.
+func (vfs *VirtualFilesystem) saveMounts() []*Mount {
+ if atomic.LoadPointer(&vfs.mounts.slots) == nil {
+ // vfs.Init() was never called.
+ return nil
+ }
+ var mounts []*Mount
+ vfs.mounts.Range(func(mount *Mount) bool {
+ mounts = append(mounts, mount)
+ return true
+ })
+ return mounts
+}
+
+// loadMounts is called by stateify.
+func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
+ if mounts == nil {
+ return
+ }
+ vfs.mounts.Init()
+ for _, mount := range mounts {
+ vfs.mounts.Insert(mount)
+ }
+}
+
+func (mnt *Mount) afterLoad() {
+ if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&mnt.refs) != 0 {
+ refsvfs2.Register(mnt, "vfs.Mount")
+ }
+}
+
+// afterLoad is called by stateify.
+func (epi *epollInterest) afterLoad() {
+ // Mark all epollInterests as ready after restore so that the next call to
+ // EpollInstance.ReadEvents() rechecks their readiness.
+ epi.Callback(nil)
+}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 38d2701d2..48d6252f7 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -71,7 +71,7 @@ type VirtualFilesystem struct {
// points.
//
// mounts is analogous to Linux's mount_hashtable.
- mounts mountTable
+ mounts mountTable `state:".([]*Mount)"`
// mountpoints maps mount points to mounts at those points in all
// namespaces. mountpoints is protected by mountMu.
@@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
// SyncAllFilesystems has the semantics of Linux's sync(2).
func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ var retErr error
+ for fs := range vfs.getFilesystems() {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef(ctx)
+ }
+ return retErr
+}
+
+func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} {
fss := make(map[*Filesystem]struct{})
vfs.filesystemsMu.Lock()
+ defer vfs.filesystemsMu.Unlock()
for fs := range vfs.filesystems {
if !fs.TryIncRef() {
continue
}
fss[fs] = struct{}{}
}
- vfs.filesystemsMu.Unlock()
- var retErr error
- for fs := range fss {
- if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
- retErr = err
- }
- fs.DecRef(ctx)
- }
- return retErr
+ return fss
}
// MkdirAllAt recursively creates non-existent directories on the given path
diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD
index ba2ed1ea7..abb8c3be3 100644
--- a/pkg/shim/v2/runtimeoptions/BUILD
+++ b/pkg/shim/v2/runtimeoptions/BUILD
@@ -11,12 +11,12 @@ proto_library(
go_library(
name = "runtimeoptions",
- srcs = ["runtimeoptions.go"],
- visibility = ["//pkg/shim/v2:__pkg__"],
- deps = [
- ":api_go_proto",
- "@com_github_gogo_protobuf//proto:go_default_library",
+ srcs = [
+ "runtimeoptions.go",
+ "runtimeoptions_cri.go",
],
+ visibility = ["//pkg/shim/v2:__pkg__"],
+ deps = ["@com_github_gogo_protobuf//proto:go_default_library"],
)
go_test(
@@ -27,6 +27,6 @@ go_test(
deps = [
"@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
"@com_github_containerd_typeurl//:go_default_library",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "@com_github_gogo_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
index aaf17b87a..072dd87f0 100644
--- a/pkg/shim/v2/runtimeoptions/runtimeoptions.go
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
@@ -13,18 +13,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package runtimeoptions contains the runtimeoptions proto.
package runtimeoptions
-
-import (
- proto "github.com/gogo/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto"
-)
-
-type Options = pb.Options
-
-func init() {
- // The generated proto file auto registers with "golang/protobuf/proto"
- // package. However, typeurl uses "golang/gogo/protobuf/proto". So registers
- // the type there too.
- proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
-}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go
new file mode 100644
index 000000000..e6102b4cf
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go
@@ -0,0 +1,383 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runtimeoptions
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+ "strings"
+
+ proto "github.com/gogo/protobuf/proto"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package
+
+type Options struct {
+ // TypeUrl specifies the type of the content inside the config file.
+ TypeUrl string `protobuf:"bytes,1,opt,name=type_url,json=typeUrl,proto3" json:"type_url,omitempty"`
+ // ConfigPath specifies the filesystem location of the config file
+ // used by the runtime.
+ ConfigPath string `protobuf:"bytes,2,opt,name=config_path,json=configPath,proto3" json:"config_path,omitempty"`
+}
+
+func (m *Options) Reset() { *m = Options{} }
+func (*Options) ProtoMessage() {}
+func (*Options) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{0} }
+
+func (m *Options) GetTypeUrl() string {
+ if m != nil {
+ return m.TypeUrl
+ }
+ return ""
+}
+
+func (m *Options) GetConfigPath() string {
+ if m != nil {
+ return m.ConfigPath
+ }
+ return ""
+}
+
+func init() {
+ proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
+}
+
+func (m *Options) Marshal() (dAtA []byte, err error) {
+ size := m.Size()
+ dAtA = make([]byte, size)
+ n, err := m.MarshalTo(dAtA)
+ if err != nil {
+ return nil, err
+ }
+ return dAtA[:n], nil
+}
+
+func (m *Options) MarshalTo(dAtA []byte) (int, error) {
+ var i int
+ _ = i
+ var l int
+ _ = l
+ if len(m.TypeUrl) > 0 {
+ dAtA[i] = 0xa
+ i++
+ i = encodeVarintApi(dAtA, i, uint64(len(m.TypeUrl)))
+ i += copy(dAtA[i:], m.TypeUrl)
+ }
+ if len(m.ConfigPath) > 0 {
+ dAtA[i] = 0x12
+ i++
+ i = encodeVarintApi(dAtA, i, uint64(len(m.ConfigPath)))
+ i += copy(dAtA[i:], m.ConfigPath)
+ }
+ return i, nil
+}
+
+func encodeVarintApi(dAtA []byte, offset int, v uint64) int {
+ for v >= 1<<7 {
+ dAtA[offset] = uint8(v&0x7f | 0x80)
+ v >>= 7
+ offset++
+ }
+ dAtA[offset] = uint8(v)
+ return offset + 1
+}
+
+func (m *Options) Size() (n int) {
+ var l int
+ _ = l
+ l = len(m.TypeUrl)
+ if l > 0 {
+ n += 1 + l + sovApi(uint64(l))
+ }
+ l = len(m.ConfigPath)
+ if l > 0 {
+ n += 1 + l + sovApi(uint64(l))
+ }
+ return n
+}
+
+func sovApi(x uint64) (n int) {
+ for {
+ n++
+ x >>= 7
+ if x == 0 {
+ break
+ }
+ }
+ return n
+}
+
+func sozApi(x uint64) (n int) {
+ return sovApi(uint64((x << 1) ^ uint64((int64(x) >> 63))))
+}
+
+func (this *Options) String() string {
+ if this == nil {
+ return "nil"
+ }
+ s := strings.Join([]string{`&Options{`,
+ `TypeUrl:` + fmt.Sprintf("%v", this.TypeUrl) + `,`,
+ `ConfigPath:` + fmt.Sprintf("%v", this.ConfigPath) + `,`,
+ `}`,
+ }, "")
+ return s
+}
+
+func valueToStringApi(v interface{}) string {
+ rv := reflect.ValueOf(v)
+ if rv.IsNil() {
+ return "nil"
+ }
+ pv := reflect.Indirect(rv).Interface()
+ return fmt.Sprintf("*%v", pv)
+}
+
+func (m *Options) Unmarshal(dAtA []byte) error {
+ l := len(dAtA)
+ iNdEx := 0
+ for iNdEx < l {
+ preIndex := iNdEx
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ fieldNum := int32(wire >> 3)
+ wireType := int(wire & 0x7)
+ if wireType == 4 {
+ return fmt.Errorf("proto: Options: wiretype end group for non-group")
+ }
+ if fieldNum <= 0 {
+ return fmt.Errorf("proto: Options: illegal tag %d (wire type %d)", fieldNum, wire)
+ }
+ switch fieldNum {
+ case 1:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field TypeUrl", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthApi
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.TypeUrl = string(dAtA[iNdEx:postIndex])
+ iNdEx = postIndex
+ case 2:
+ if wireType != 2 {
+ return fmt.Errorf("proto: wrong wireType = %d for field ConfigPath", wireType)
+ }
+ var stringLen uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ stringLen |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ intStringLen := int(stringLen)
+ if intStringLen < 0 {
+ return ErrInvalidLengthApi
+ }
+ postIndex := iNdEx + intStringLen
+ if postIndex > l {
+ return io.ErrUnexpectedEOF
+ }
+ m.ConfigPath = string(dAtA[iNdEx:postIndex])
+ iNdEx = postIndex
+ default:
+ iNdEx = preIndex
+ skippy, err := skipApi(dAtA[iNdEx:])
+ if err != nil {
+ return err
+ }
+ if skippy < 0 {
+ return ErrInvalidLengthApi
+ }
+ if (iNdEx + skippy) > l {
+ return io.ErrUnexpectedEOF
+ }
+ iNdEx += skippy
+ }
+ }
+
+ if iNdEx > l {
+ return io.ErrUnexpectedEOF
+ }
+ return nil
+}
+
+func skipApi(dAtA []byte) (n int, err error) {
+ l := len(dAtA)
+ iNdEx := 0
+ for iNdEx < l {
+ var wire uint64
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return 0, ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return 0, io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ wire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ wireType := int(wire & 0x7)
+ switch wireType {
+ case 0:
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return 0, ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return 0, io.ErrUnexpectedEOF
+ }
+ iNdEx++
+ if dAtA[iNdEx-1] < 0x80 {
+ break
+ }
+ }
+ return iNdEx, nil
+ case 1:
+ iNdEx += 8
+ return iNdEx, nil
+ case 2:
+ var length int
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return 0, ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return 0, io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ length |= (int(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ iNdEx += length
+ if length < 0 {
+ return 0, ErrInvalidLengthApi
+ }
+ return iNdEx, nil
+ case 3:
+ for {
+ var innerWire uint64
+ var start int = iNdEx
+ for shift := uint(0); ; shift += 7 {
+ if shift >= 64 {
+ return 0, ErrIntOverflowApi
+ }
+ if iNdEx >= l {
+ return 0, io.ErrUnexpectedEOF
+ }
+ b := dAtA[iNdEx]
+ iNdEx++
+ innerWire |= (uint64(b) & 0x7F) << shift
+ if b < 0x80 {
+ break
+ }
+ }
+ innerWireType := int(innerWire & 0x7)
+ if innerWireType == 4 {
+ break
+ }
+ next, err := skipApi(dAtA[start:])
+ if err != nil {
+ return 0, err
+ }
+ iNdEx = start + next
+ }
+ return iNdEx, nil
+ case 4:
+ return iNdEx, nil
+ case 5:
+ iNdEx += 4
+ return iNdEx, nil
+ default:
+ return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
+ }
+ }
+ panic("unreachable")
+}
+
+var (
+ ErrInvalidLengthApi = fmt.Errorf("proto: negative length found during unmarshaling")
+ ErrIntOverflowApi = fmt.Errorf("proto: integer overflow")
+)
+
+func init() { proto.RegisterFile("api.proto", fileDescriptorApi) }
+
+var fileDescriptorApi = []byte{
+ // 183 bytes of a gzipped FileDescriptorProto
+ 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0x2c, 0xc8, 0xd4,
+ 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x4d, 0x2e, 0xca, 0xd4, 0x2b, 0x2a, 0xcd, 0x2b, 0xc9,
+ 0xcc, 0x4d, 0xcd, 0x2f, 0x28, 0xc9, 0xcc, 0xcf, 0x2b, 0xd6, 0x2b, 0x33, 0x94, 0xd2, 0x4d, 0xcf,
+ 0x2c, 0xc9, 0x28, 0x4d, 0xd2, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xcf, 0x4f, 0xcf, 0xd7, 0x07, 0xab,
+ 0x4e, 0x2a, 0x4d, 0x03, 0xf3, 0xc0, 0x1c, 0x30, 0x0b, 0x62, 0x8a, 0x92, 0x2b, 0x17, 0xbb, 0x3f,
+ 0x44, 0xb3, 0x90, 0x24, 0x17, 0x47, 0x49, 0x65, 0x41, 0x6a, 0x7c, 0x69, 0x51, 0x8e, 0x04, 0xa3,
+ 0x02, 0xa3, 0x06, 0x67, 0x10, 0x3b, 0x88, 0x1f, 0x5a, 0x94, 0x23, 0x24, 0xcf, 0xc5, 0x9d, 0x9c,
+ 0x9f, 0x97, 0x96, 0x99, 0x1e, 0x5f, 0x90, 0x58, 0x92, 0x21, 0xc1, 0x04, 0x96, 0xe5, 0x82, 0x08,
+ 0x05, 0x24, 0x96, 0x64, 0x38, 0xc9, 0x9c, 0x78, 0x28, 0xc7, 0x78, 0xe3, 0xa1, 0x1c, 0x43, 0xc3,
+ 0x23, 0x39, 0xc6, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71,
+ 0xc2, 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0x5d, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x07,
+ 0x00, 0xf2, 0x18, 0xbe, 0x00, 0x00, 0x00,
+}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go
index f4c238a00..c59a2400e 100644
--- a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go
@@ -15,11 +15,12 @@
package runtimeoptions
import (
+ "bytes"
"testing"
shim "github.com/containerd/containerd/runtime/v1/shim/v1"
"github.com/containerd/typeurl"
- "github.com/golang/protobuf/proto"
+ "github.com/gogo/protobuf/proto"
)
func TestCreateTaskRequest(t *testing.T) {
@@ -32,7 +33,11 @@ func TestCreateTaskRequest(t *testing.T) {
if err := proto.UnmarshalText(encodedText, got); err != nil {
t.Fatalf("unable to unmarshal text: %v", err)
}
- t.Logf("got: %s", proto.MarshalTextString(got))
+ var textBuffer bytes.Buffer
+ if err := proto.MarshalText(&textBuffer, got); err != nil {
+ t.Errorf("unable to marshal text: %v", err)
+ }
+ t.Logf("got: %s", string(textBuffer.Bytes()))
// Check the options.
wantOptions := &Options{}
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index 089b3bbef..92c51879b 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "pending_list",
- out = "pending_list.go",
- package = "state",
- prefix = "pending",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*objectEncodeState",
- "ElementMapper": "pendingMapper",
- "Linker": "*pendingEntry",
- },
-)
-
-go_template_instance(
name = "deferred_list",
out = "deferred_list.go",
package = "state",
@@ -83,7 +70,6 @@ go_library(
"deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "pending_list.go",
"state.go",
"state_norace.go",
"state_race.go",
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 89467ca8e..e519ddeca 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -21,6 +21,7 @@ import (
"math"
"reflect"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/state/wire"
)
@@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c
// For the purposes of this function, a child object is either a field within a
// struct or an array element, with one such indirection per element in
// path. The returned value may be an unexported field, so it may not be
-// directly assignable. See unsafePointerTo.
+// directly assignable. See decode_unsafe.go.
func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
// See wire.Ref.Dots. The path here is specified in reverse order.
for i := len(path) - 1; i >= 0; i-- {
@@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// Normal assignment: authoritative only if no dots.
v := ds.register(x, obj.Type().Elem())
- if v.IsValid() {
- obj.Set(unsafePointerTo(v))
- }
+ obj.Set(reflectValueRWAddr(v))
case wire.Bool:
obj.SetBool(bool(x))
case wire.Int:
@@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e
// contents will still be filled in later on.
typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
v := ds.register(&x.Ref, typ)
- obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity)))
case *wire.Array:
ds.decodeArray(ods, obj, x)
case *wire.Struct:
@@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) {
ds.pending.PushBack(rootOds)
// Read the number of objects.
- lastID, object, err := ReadHeader(ds.r)
+ numObjects, object, err := ReadHeader(ds.r)
if err != nil {
Failf("header error: %w", err)
}
@@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) {
var (
encoded wire.Object
ods *objectDecodeState
- id = objectID(1)
+ id objectID
tid = typeID(1)
)
if err := safely(func() {
// Decode all objects in the stream.
//
- // Note that the structure of this decoding loop should match
- // the raw decoding loop in printer.go.
- for id <= objectID(lastID) {
- // Unmarshal the object.
+ // Note that the structure of this decoding loop should match the raw
+ // decoding loop in state/pretty/pretty.printer.printStream().
+ for i := uint64(0); i < numObjects; {
+ // Unmarshal either a type object or object ID.
encoded = wire.Load(ds.r)
-
- // Is this a type object? Handle inline.
- if wt, ok := encoded.(*wire.Type); ok {
- ds.types.Register(wt)
+ switch we := encoded.(type) {
+ case *wire.Type:
+ ds.types.Register(we)
tid++
encoded = nil
continue
+ case wire.Uint:
+ id = objectID(we)
+ i++
+ // Unmarshal and resolve the actual object.
+ encoded = wire.Load(ds.r)
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
+ // For error handling.
+ ods = nil
+ encoded = nil
+ default:
+ Failf("wanted type or object ID, got %#v", encoded)
}
-
- // Actually resolve the object.
- ods = ds.lookup(id)
- if ods != nil {
- // Decode the object.
- ds.decodeObject(ods, ods.obj, encoded)
- } else {
- // If an object hasn't had interest registered
- // previously or isn't yet valid, we deferred
- // decoding until interest is registered.
- ds.deferred[id] = encoded
- }
-
- // For error handling.
- ods = nil
- encoded = nil
- id++
}
}); err != nil {
// Include as much information as we can, taking into account
@@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) {
if ods != nil {
Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
} else if encoded != nil {
- Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
+ Failf("error decoding from %#v: %w", encoded, err)
} else {
Failf("general decoding error: %w", err)
}
}
// Check if we have any deferred objects.
+ numDeferred := 0
for id, encoded := range ds.deferred {
- // Shoud never happen, the graph was bogus.
- Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
+ numDeferred++
+ if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 {
+ typ := ds.types.LookupType(typeID(s.TypeID))
+ log.Warningf("unused deferred object: ID %d, type %v", id, typ)
+ } else {
+ log.Warningf("unused deferred object: ID %d, %#v", id, encoded)
+ }
+ }
+ if numDeferred != 0 {
+ Failf("still had %d deferred objects", numDeferred)
}
// Scan and fire all callbacks. We iterate over the list of incomplete
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
index d048f61a1..f1208e2a2 100644
--- a/pkg/state/decode_unsafe.go
+++ b/pkg/state/decode_unsafe.go
@@ -15,13 +15,62 @@
package state
import (
+ "fmt"
"reflect"
+ "runtime"
"unsafe"
)
-// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
-// values representing unexported fields. This bypasses visibility, but not
-// type safety.
-func unsafePointerTo(obj reflect.Value) reflect.Value {
+// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned
+// reflect.Value is usable in assignments even if obj was obtained by the use
+// of unexported struct fields.
+//
+// Preconditions: obj.CanAddr().
+func reflectValueRWAddr(obj reflect.Value) reflect.Value {
return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
}
+
+// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the
+// returned reflect.Value is usable in assignments even if obj was obtained by
+// the use of unexported struct fields.
+//
+// Preconditions:
+// * arr.Kind() == reflect.Array.
+// * i, j, k >= 0.
+// * i <= j <= k <= arr.Len().
+func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value {
+ if arr.Kind() != reflect.Array {
+ panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array))
+ }
+ if i < 0 || j < 0 || k < 0 {
+ panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k))
+ }
+ if i > j {
+ panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j))
+ }
+ if j > k {
+ panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k))
+ }
+ if k > arr.Len() {
+ panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len()))
+ }
+
+ sliceTyp := reflect.SliceOf(arr.Type().Elem())
+ if i == arr.Len() {
+ // By precondition, i == j == k == arr.Len().
+ return reflect.MakeSlice(sliceTyp, 0, 0)
+ }
+ slh := reflect.SliceHeader{
+ // reflect.Value.CanAddr() == false for arrays, so we need to get the
+ // address from the first element of the array.
+ Data: arr.Index(i).UnsafeAddr(),
+ Len: j - i,
+ Cap: k - i,
+ }
+ slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem()
+ // Before slobj is constructed, arr holds the only pointer-typed pointer to
+ // the array since reflect.SliceHeader.Data is a uintptr, so arr must be
+ // kept alive.
+ runtime.KeepAlive(arr)
+ return slobj
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 92fcad4e9..560e7c2a3 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -17,13 +17,14 @@ package state
import (
"context"
"reflect"
+ "sort"
"gvisor.dev/gvisor/pkg/state/wire"
)
// objectEncodeState the type and identity of an object occupying a memory
// address range. This is the value type for addrSet, and the intrusive entry
-// for the pending and deferred lists.
+// for the deferred list.
type objectEncodeState struct {
// id is the assigned ID for this object.
id objectID
@@ -47,7 +48,6 @@ type objectEncodeState struct {
// references may be updated directly and automatically.
refs []*wire.Ref
- pendingEntry
deferredEntry
}
@@ -93,9 +93,15 @@ type encodeState struct {
// serialized.
pendingTypes []wire.Type
- // pending is the list of objects to be serialized. Serialization does
+ // pending maps object IDs to objects to be serialized. Serialization does
// not actually occur until the full object graph is computed.
- pending pendingList
+ pending map[objectID]*objectEncodeState
+
+ // encodedStructs maps reflect.Values representing structs to previous
+ // encodings of those structs. This is necessary to avoid duplicate calls
+ // to SaverLoader.StateSave() that may result in multiple calls to
+ // Sink.SaveValue() for a given field, resulting in object duplication.
+ encodedStructs map[reflect.Value]*wire.Struct
// stats tracks time data.
stats Stats
@@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
// depending on this value knows there's nothing there.
return
}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg, gap := es.values.Find(addr)
+ if seg.Ok() {
// Ensure the map types match.
existing := seg.Value()
if existing.obj.Type() != obj.Type() {
@@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
}
// Record the map.
+ r := addrRange{addr, addr + 1}
oes := &objectEncodeState{
id: es.nextID(),
obj: obj,
how: encodeMapAsValue,
}
- es.values.Add(addrRange{addr, addr + 1}, oes)
- es.pending.PushBack(oes)
+ // Use Insert instead of InsertWithoutMergingUnchecked when race
+ // detection is enabled to get additional sanity-checking from Merge.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
// See above: no ref recording.
@@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
obj: obj,
}
es.zeroValues[typ] = oes
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
}
@@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
size = 1 // See above.
}
- // Calculate the container.
end := addr + size
r := addrRange{addr, end}
- if seg, _ := es.values.Find(addr); seg.Ok() {
+ seg := es.values.LowerBoundSegment(addr)
+ var (
+ oes *objectEncodeState
+ gap addrGapIterator
+ )
+
+ // Does at least one previously-registered object overlap this one?
+ if seg.Ok() && seg.Start() < end {
existing := seg.Value()
- switch {
- case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
- // The object is a perfect match. Happy path. Avoid the
- // traversal and just return directly. We don't need to
- // encode the type information or any dots here.
+
+ if seg.Range() == r && typ == existing.obj.Type() {
+ // This exact object is already registered. Avoid the traversal and
+ // just return directly. We don't need to encode the type
+ // information or any dots here.
ref.Root = wire.Uint(existing.id)
existing.refs = append(existing.refs, ref)
return
+ }
- case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
- // The previously registered object is larger than
- // this, no need to update. But we expect some
- // traversal below.
+ if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) {
+ // This object is contained within a previously-registered object.
+ // Perform traversal from the container to the new object.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr)
+ ref.Type = es.findType(existing.obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
- case seg.Start() == addr && seg.End() == end:
- if !isSameSizeParent(obj, existing.obj.Type()) {
- break // Needs traversal.
+ // This object contains one or more previously-registered objects.
+ // Remove them and update existing references to use the new one.
+ oes := &objectEncodeState{
+ // Reuse the root ID of the first contained element.
+ id: existing.id,
+ obj: obj,
+ }
+ type elementEncodeState struct {
+ addr uintptr
+ typ reflect.Type
+ refs []*wire.Ref
+ }
+ var (
+ elems []elementEncodeState
+ gap addrGapIterator
+ )
+ for {
+ // Each contained object should be completely contained within
+ // this one.
+ if raceEnabled && !r.IsSupersetOf(seg.Range()) {
+ Failf("containing object %#v does not contain existing object %#v", obj, existing.obj)
}
- fallthrough // Needs update.
-
- case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
- // Update the object and redo the encoding.
- old := existing.obj
- existing.obj = obj
+ elems = append(elems, elementEncodeState{
+ addr: seg.Start(),
+ typ: existing.obj.Type(),
+ refs: existing.refs,
+ })
+ delete(es.pending, existing.id)
es.deferred.Remove(existing)
- es.deferred.PushBack(existing)
-
- // The previously registered object is superseded by
- // this new object. We are guaranteed to not have any
- // mergeable neighbours in this segment set.
- if !raceEnabled {
- seg.SetRangeUnchecked(r)
- } else {
- // Add extra paranoid. This will be statically
- // removed at compile time unless a race build.
- es.values.Remove(seg)
- es.values.Add(r, existing)
- seg = es.values.LowerBoundSegment(addr)
+ gap = es.values.Remove(seg)
+ seg = gap.NextSegment()
+ if !seg.Ok() || seg.Start() >= end {
+ break
}
-
- // Compute the traversal required & update references.
- dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
- wt := es.findType(obj.Type())
- for _, ref := range existing.refs {
+ existing = seg.Value()
+ }
+ wt := es.findType(typ)
+ for _, elem := range elems {
+ dots := traverse(typ, elem.typ, addr, elem.addr)
+ for _, ref := range elem.refs {
+ ref.Root = wire.Uint(oes.id)
ref.Dots = append(ref.Dots, dots...)
ref.Type = wt
}
- default:
- // There is a non-sensical overlap.
- Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
+ oes.refs = append(oes.refs, elem.refs...)
}
-
- // Compute the new reference, record and return it.
- ref.Root = wire.Uint(existing.id)
- ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
- ref.Type = es.findType(obj.Type())
- existing.refs = append(existing.refs, ref)
+ // Finally register the new containing object.
+ if !raceEnabled {
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
+ } else {
+ es.values.Insert(gap, r, oes)
+ }
+ es.pending[oes.id] = oes
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
return
}
- // The only remaining case is a pointer value that doesn't overlap with
- // any registered addresses. Create a new entry for it, and start
- // tracking the first reference we just created.
- oes := &objectEncodeState{
+ // No existing object overlaps this one. Register a new object.
+ oes = &objectEncodeState{
id: es.nextID(),
obj: obj,
}
+ if seg.Ok() {
+ gap = seg.PrevGap()
+ } else {
+ gap = es.values.LastGap()
+ }
if !raceEnabled {
- es.values.AddWithoutMerging(r, oes)
+ es.values.InsertWithoutMergingUnchecked(gap, r, oes)
} else {
- // Merges should never happen. This is just enabled extra
- // sanity checks because the Merge function below will panic.
- es.values.Add(r, oes)
+ es.values.Insert(gap, r, oes)
}
- es.pending.PushBack(oes)
+ es.pending[oes.id] = oes
es.deferred.PushBack(oes)
ref.Root = wire.Uint(oes.id)
oes.refs = append(oes.refs, ref)
@@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) {
// encodeStruct encodes a composite object.
func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ if s, ok := es.encodedStructs[obj]; ok {
+ *dest = s
+ return
+ }
+ s := &wire.Struct{}
+ *dest = s
+ es.encodedStructs[obj] = s
+
// Ensure that the obj is addressable. There are two cases when it is
// not. First, is when this is dispatched via SaveValue. Second, when
// this is a map key as a struct. Either way, we need to make a copy to
@@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
obj = localObj.Elem()
}
- // Prepare the value.
- s := &wire.Struct{}
- *dest = s
-
// Look the type up in the database.
te, ok := es.types.Lookup(obj.Type())
if te == nil {
@@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) {
Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
}
- // Check that items are pending.
- if es.pending.Front() == nil {
+ // Check that we have objects to serialize.
+ if len(es.pending) == 0 {
Failf("pending is empty?")
}
- // Write the header with the number of objects. Note that there is no
- // way that es.lastID could conflict with objectID, which would
- // indicate that an impossibly large encoding.
- if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ // Write the header with the number of objects.
+ if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil {
Failf("error writing header: %w", err)
}
// Serialize all pending types and pending objects. Note that we don't
// bother removing from this list as we walk it because that just
// wastes time. It will not change after this point.
- var id objectID
if err := safely(func() {
for _, wt := range es.pendingTypes {
// Encode the type.
wire.Save(es.w, &wt)
}
- for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
- id++ // First object is 1.
- if oes.id != id {
- Failf("expected id %d, got %d", id, oes.id)
- }
-
- // Marshall the object.
+ // Emit objects in ID order.
+ ids := make([]objectID, 0, len(es.pending))
+ for id := range es.pending {
+ ids = append(ids, id)
+ }
+ sort.Slice(ids, func(i, j int) bool {
+ return ids[i] < ids[j]
+ })
+ for _, id := range ids {
+ // Encode the id.
+ wire.Save(es.w, wire.Uint(id))
+ // Marshal the object.
+ oes := es.pending[id]
wire.Save(es.w, oes.encoded)
}
}); err != nil {
// Include the object and the error.
Failf("error serializing object %#v: %w", oes.encoded, err)
}
-
- // Check what we wrote.
- if id != es.lastID {
- Failf("expected %d objects, wrote %d", es.lastID, id)
- }
}
// objectFlag indicates that the length is a # of objects, rather than a raw
@@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error {
})
}
-// pendingMapper is for the pending list.
-type pendingMapper struct{}
-
-func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
-
// deferredMapper is for the deferred list.
type deferredMapper struct{}
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
index 887f453a9..c6e8bb31d 100644
--- a/pkg/state/pretty/pretty.go
+++ b/pkg/state/pretty/pretty.go
@@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
buf.WriteString(typ)
buf.WriteString(")(")
buf.WriteString(baseRef)
+ buf.WriteString(")")
for _, component := range x.Dots {
switch v := component.(type) {
case *wire.FieldName:
@@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string {
panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
}
}
- buf.WriteString(")")
fullRef = buf.String()
}
if p.html {
@@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
// Note that this loop must match the general structure of the
// loop in decode.go. But we don't register type information,
// etc. and just print the raw structures.
+ type objectAndID struct {
+ id uint64
+ obj wire.Object
+ }
var (
tid uint64 = 1
- objects []wire.Object
+ objects []objectAndID
)
- for oid := uint64(1); oid <= length; {
- // Unmarshal the object.
+ for i := uint64(0); i < length; {
+ // Unmarshal either a type object or object ID.
encoded := wire.Load(r)
-
- // Is this a type?
- if typ, ok := encoded.(*wire.Type); ok {
+ switch we := encoded.(type) {
+ case *wire.Type:
str, _ := p.format(graph, 0, encoded)
tag := fmt.Sprintf("g%dt%d", graph, tid)
- p.typeSpecs[tag] = typ
+ p.typeSpecs[tag] = we
if p.html {
// See below.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
@@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) {
return err
}
tid++
- continue
+ case wire.Uint:
+ // Unmarshal the actual object.
+ objects = append(objects, objectAndID{
+ id: uint64(we),
+ obj: wire.Load(r),
+ })
+ i++
+ default:
+ return fmt.Errorf("wanted type or object ID, got %#v", encoded)
}
-
- // Otherwise, it is a node.
- objects = append(objects, encoded)
- oid++
}
- for i, encoded := range objects {
- // oid starts at 1.
- oid := i + 1
+ for _, objAndID := range objects {
// Format the node.
- str, _ := p.format(graph, 0, encoded)
- tag := fmt.Sprintf("g%dr%d", graph, oid)
+ str, _ := p.format(graph, 0, objAndID.obj)
+ tag := fmt.Sprintf("g%dr%d", graph, objAndID.id)
if p.html {
// Create a little tag with an anchor next to it for linking.
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
diff --git a/pkg/state/state.go b/pkg/state/state.go
index acb629969..6b8540f03 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error {
func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
es := encodeState{
- ctx: ctx,
- w: w,
- types: makeTypeEncodeDatabase(),
- zeroValues: make(map[reflect.Type]*objectEncodeState),
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
+ pending: make(map[objectID]*objectEncodeState),
+ encodedStructs: make(map[reflect.Value]*wire.Struct),
}
// Perform the encoding.
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
index bd2c2b399..69143d194 100644
--- a/pkg/state/tests/struct.go
+++ b/pkg/state/tests/struct.go
@@ -54,12 +54,47 @@ type outerArray struct {
}
// +stateify savable
+type outerSlice struct {
+ inner []inner
+}
+
+// +stateify savable
type inner struct {
v int64
}
// +stateify savable
+type outerFieldValue struct {
+ inner innerFieldValue
+}
+
+// +stateify savable
+type innerFieldValue struct {
+ v int64 `state:".(*savedFieldValue)"`
+}
+
+// +stateify savable
+type savedFieldValue struct {
+ v int64
+}
+
+func (ifv *innerFieldValue) saveV() *savedFieldValue {
+ return &savedFieldValue{ifv.v}
+}
+
+func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) {
+ ifv.v = sfv.v
+}
+
+// +stateify savable
type system struct {
v1 interface{}
v2 interface{}
}
+
+// +stateify savable
+type system3 struct {
+ v1 interface{}
+ v2 interface{}
+ v3 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
index de9d17aa7..c91c2c032 100644
--- a/pkg/state/tests/struct_test.go
+++ b/pkg/state/tests/struct_test.go
@@ -15,6 +15,7 @@
package tests
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/state"
@@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) {
}
func TestEmbeddedPointers(t *testing.T) {
- var (
- ofs outerSame
- of1 outerFieldFirst
- of2 outerFieldSecond
- oa outerArray
- )
+ // Give each int64 a random value to prevent Go from using
+ // runtime.staticuint64s, which confounds tests for struct duplication.
+ magic := func() int64 {
+ for {
+ n := rand.Int63()
+ if n < 0 || n > 255 {
+ return n
+ }
+ }
+ }
+
+ ofs := outerSame{inner{magic()}}
+ of1 := outerFieldFirst{inner{magic()}, magic()}
+ of2 := outerFieldSecond{magic(), inner{magic()}}
+ oa := outerArray{[2]inner{{magic()}, {magic()}}}
+ osl := outerSlice{oa.inner[:]}
+ ofv := outerFieldValue{innerFieldValue{magic()}}
runTestCases(t, false, "embedded-pointers", []interface{}{
system{&ofs, &ofs.inner},
@@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) {
system{&oa, &oa.inner[1]},
system{&oa.inner[0], &oa},
system{&oa.inner[1], &oa},
+ system3{&oa, &oa.inner[0], &oa.inner[1]},
+ system3{&oa, &oa.inner[1], &oa.inner[0]},
+ system3{&oa.inner[0], &oa, &oa.inner[1]},
+ system3{&oa.inner[1], &oa, &oa.inner[0]},
+ system3{&oa.inner[0], &oa.inner[1], &oa},
+ system3{&oa.inner[1], &oa.inner[0], &oa},
+ system{&oa, &osl},
+ system{&osl, &oa},
+ system{&ofv, &ofv.inner},
+ system{&ofv.inner, &ofv},
})
}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 12b061def..b196324c7 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -97,6 +97,9 @@ type testConnection struct {
func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
wq := &waiter.Queue{}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ return nil, err
+ }
entry, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&entry, waiter.EventOut)
@@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
// Give c.Read() a chance to block before closing the connection.
@@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) {
defer close(done)
c, err := l.Accept()
if err != nil {
- t.Fatalf("l.Accept() = %v", err)
+ t.Errorf("l.Accept() = %v", err)
+ // Cannot call Fatalf in goroutine. Just return from the goroutine.
+ return
}
c.SetDeadline(time.Now().Add(time.Minute))
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index d4d785cca..530f2ae2f 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker {
}
}
+// IPPayload creates a checker that checks the payload.
+func IPPayload(payload []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ got := h[0].Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(got) == 0 && len(payload) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(payload, got); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// IPv4Options returns a checker that checks the options in an IPv4 packet.
func IPv4Options(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -187,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker {
if !ok {
t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0])
}
- options := ip.Options()
+ options := []byte(ip.Options())
// cmp.Diff does not consider nil slices equal to empty slices, but we do.
if len(want) == 0 && len(options) == 0 {
return
@@ -841,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker {
}
}
+// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer.
+func ICMPv4Pointer(want uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h)
+ }
+ if got := icmpv4.Pointer(); got != want {
+ t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want)
+ }
+ }
+}
+
// ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum.
// This assumes that the payload exactly makes up the rest of the slice.
func ICMPv4Checksum() TransportChecker {
@@ -935,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
}
}
+// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific
+// field.
+func ICMPv6TypeSpecific(want uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ if got := icmpv6.TypeSpecific(); got != want {
+ t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet.
+func ICMPv6Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h)
+ }
+ payload := icmpv6.Payload()
+ if diff := cmp.Diff(want, payload); diff != "" {
+ t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// NDP creates a checker that checks that the packet contains a valid NDP
// message for type of ty, with potentially additional checks specified by
// checkers.
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 504408878..2f13dea6a 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -99,7 +99,8 @@ const (
// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
- ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4ReassemblyTimeout ICMPv4Code = 1
)
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
@@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
+// Pointer returns the pointer field in a Parameter Problem packet.
+func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] }
+
+// SetPointer sets the pointer field in a Parameter Problem packet.
+func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
+
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 4c6e4be64..961b77628 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -16,6 +16,7 @@ package header
import (
"encoding/binary"
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,7 +39,6 @@ import (
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Options | Padding |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-//
const (
versIHL = 0
tos = 1
@@ -93,7 +93,7 @@ type IPv4Fields struct {
DstAddr tcpip.Address
}
-// IPv4 represents an ipv4 header stored in a byte array.
+// IPv4 is an IPv4 header.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other
@@ -106,10 +106,13 @@ const (
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
- // that there are only 4 bits to represents the header length in 32-bit
- // units, the header cannot exceed 15*4 = 60 bytes.
+ // that there are only 4 bits (max 0xF (15)) to represent the header length
+ // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // IPv4MaximumOptionsSize is the largest size the IPv4 options can be.
+ IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize
+
// IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload.
//
// Linux limits this to 65,515 octets (the max IP datagram size - the IPv4
@@ -130,7 +133,7 @@ const (
// IPv4ProtocolNumber is IPv4's network protocol number.
IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
- // IPv4Version is the version of the ipv4 protocol.
+ // IPv4Version is the version of the IPv4 protocol.
IPv4Version = 4
// IPv4AllSystems is the all systems IPv4 multicast address as per
@@ -148,6 +151,13 @@ const (
// packet that every IPv4 capable host must be able to
// process/reassemble.
IPv4MinimumProcessableDatagramSize = 576
+
+ // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791,
+ // section 3.2:
+ // Every internet module must be able to forward a datagram of 68 octets
+ // without further fragmentation. This is because an internet header may be
+ // up to 60 octets, and the minimum fragment is 8 octets.
+ IPv4MinimumMTU = 68
)
// Flags that may be set in an IPv4 packet.
@@ -191,14 +201,13 @@ func IPVersion(b []byte) int {
// Internet Header Length is the length of the internet header in 32
// bit words, and thus points to the beginning of the data. Note that
// the minimum value for a correct header is 5.
-//
const (
ipVersionShift = 4
ipIHLMask = 0x0f
IPv4IHLStride = 4
)
-// HeaderLength returns the value of the "header length" field of the ipv4
+// HeaderLength returns the value of the "header length" field of the IPv4
// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & ipIHLMask) * IPv4IHLStride
@@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) {
b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask)
}
-// ID returns the value of the identifier field of the ipv4 header.
+// ID returns the value of the identifier field of the IPv4 header.
func (b IPv4) ID() uint16 {
return binary.BigEndian.Uint16(b[id:])
}
-// Protocol returns the value of the protocol field of the ipv4 header.
+// Protocol returns the value of the protocol field of the IPv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
-// Flags returns the "flags" field of the ipv4 header.
+// Flags returns the "flags" field of the IPv4 header.
func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
@@ -232,41 +241,44 @@ func (b IPv4) More() bool {
return b.Flags()&IPv4FlagMoreFragments != 0
}
-// TTL returns the "TTL" field of the ipv4 header.
+// TTL returns the "TTL" field of the IPv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
}
-// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+// FragmentOffset returns the "fragment offset" field of the IPv4 header.
func (b IPv4) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}
-// TotalLength returns the "total length" field of the ipv4 header.
+// TotalLength returns the "total length" field of the IPv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
-// Checksum returns the checksum field of the ipv4 header.
+// Checksum returns the checksum field of the IPv4 header.
func (b IPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[checksum:])
}
-// SourceAddress returns the "source address" field of the ipv4 header.
+// SourceAddress returns the "source address" field of the IPv4 header.
func (b IPv4) SourceAddress() tcpip.Address {
return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
}
-// DestinationAddress returns the "destination address" field of the ipv4
+// DestinationAddress returns the "destination address" field of the IPv4
// header.
func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
-// Options returns a a buffer holding the options.
-func (b IPv4) Options() []byte {
+// IPv4Options is a buffer that holds all the raw IP options.
+type IPv4Options []byte
+
+// Options returns a buffer holding the options.
+func (b IPv4) Options() IPv4Options {
hdrLen := b.HeaderLength()
- return b[options:hdrLen:hdrLen]
+ return IPv4Options(b[options:hdrLen:hdrLen])
}
// TransportProtocol implements Network.TransportProtocol.
@@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte {
return b[b.HeaderLength():][:b.PayloadLength()]
}
-// PayloadLength returns the length of the payload portion of the ipv4 packet.
+// PayloadLength returns the length of the payload portion of the IPv4 packet.
func (b IPv4) PayloadLength() uint16 {
return b.TotalLength() - uint16(b.HeaderLength())
}
-// TOS returns the "type of service" field of the ipv4 header.
+// TOS returns the "type of service" field of the IPv4 header.
func (b IPv4) TOS() (uint8, uint32) {
return b[tos], 0
}
-// SetTOS sets the "type of service" field of the ipv4 header.
+// SetTOS sets the "type of service" field of the IPv4 header.
func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
@@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) {
b[ttl] = v
}
-// SetTotalLength sets the "total length" field of the ipv4 header.
+// SetTotalLength sets the "total length" field of the IPv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
-// SetChecksum sets the checksum field of the ipv4 header.
+// SetChecksum sets the checksum field of the IPv4 header.
func (b IPv4) SetChecksum(v uint16) {
binary.BigEndian.PutUint16(b[checksum:], v)
}
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
-// ipv4 header.
+// IPv4 header.
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
@@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) {
binary.BigEndian.PutUint16(b[id:], v)
}
-// SetSourceAddress sets the "source address" field of the ipv4 header.
+// SetSourceAddress sets the "source address" field of the IPv4 header.
func (b IPv4) SetSourceAddress(addr tcpip.Address) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
}
-// SetDestinationAddress sets the "destination address" field of the ipv4
+// SetDestinationAddress sets the "destination address" field of the IPv4
// header.
func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
}
-// CalculateChecksum calculates the checksum of the ipv4 header.
+// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return Checksum(b[:b.HeaderLength()], 0)
}
-// Encode encodes all the fields of the ipv4 header.
+// Encode encodes all the fields of the IPv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
b.SetHeaderLength(i.IHL)
b[tos] = i.TOS
@@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
}
-// EncodePartial updates the total length and checksum fields of ipv4 header,
+// EncodePartial updates the total length and checksum fields of IPv4 header,
// taking in the partial checksum, which is the checksum of the header without
// the total length and checksum fields. It is useful in cases when similar
// packets are produced.
@@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool {
}
return addr[0] == 0x7f
}
+
+// ========================= Options ==========================
+
+// An IPv4OptionType can hold the valuse for the Type in an IPv4 option.
+type IPv4OptionType byte
+
+// These constants are needed to identify individual options in the option list.
+// While RFC 791 (page 31) says "Every internet module must be able to act on
+// every option." This has not generally been adhered to and some options have
+// very low rates of support. We do not support options other than those shown
+// below.
+
+const (
+ // IPv4OptionListEndType is the option type for the End Of Option List
+ // option. Anything following is ignored.
+ IPv4OptionListEndType IPv4OptionType = 0
+
+ // IPv4OptionNOPType is the No-Operation option. May appear between other
+ // options and may appear multiple times.
+ IPv4OptionNOPType IPv4OptionType = 1
+
+ // IPv4OptionRecordRouteType is used by each router on the path of the packet
+ // to record its path. It is carried over to an Echo Reply.
+ IPv4OptionRecordRouteType IPv4OptionType = 7
+
+ // IPv4OptionTimestampType is the option type for the Timestamp option.
+ IPv4OptionTimestampType IPv4OptionType = 68
+
+ // ipv4OptionTypeOffset is the offset in an option of its type field.
+ ipv4OptionTypeOffset = 0
+
+ // IPv4OptionLengthOffset is the offset in an option of its length field.
+ IPv4OptionLengthOffset = 1
+)
+
+// Potential errors when parsing generic IP options.
+var (
+ ErrIPv4OptZeroLength = errors.New("zero length IP option")
+ ErrIPv4OptDuplicate = errors.New("duplicate IP option")
+ ErrIPv4OptInvalid = errors.New("invalid IP option")
+ ErrIPv4OptMalformed = errors.New("malformed IP option")
+ ErrIPv4OptionTruncated = errors.New("truncated IP option")
+ ErrIPv4OptionAddress = errors.New("bad IP option address")
+)
+
+// IPv4Option is an interface representing various option types.
+type IPv4Option interface {
+ // Type returns the type identifier of the option.
+ Type() IPv4OptionType
+
+ // Size returns the size of the option in bytes.
+ Size() uint8
+
+ // Contents returns a slice holding the contents of the option.
+ Contents() []byte
+}
+
+var _ IPv4Option = (*IPv4OptionGeneric)(nil)
+
+// IPv4OptionGeneric is an IPv4 Option of unknown type.
+type IPv4OptionGeneric []byte
+
+// Type implements IPv4Option.
+func (o *IPv4OptionGeneric) Type() IPv4OptionType {
+ return IPv4OptionType((*o)[ipv4OptionTypeOffset])
+}
+
+// Size implements IPv4Option.
+func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) }
+
+// Contents implements IPv4Option.
+func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) }
+
+// IPv4OptionIterator is an iterator pointing to a specific IP option
+// at any point of time. It also holds information as to a new options buffer
+// that we are building up to hand back to the caller.
+type IPv4OptionIterator struct {
+ options IPv4Options
+ // ErrCursor is where we are while parsing options. It is exported as any
+ // resulting ICMP packet is supposed to have a pointer to the byte within
+ // the IP packet where the error was detected.
+ ErrCursor uint8
+ nextErrCursor uint8
+ newOptions [IPv4MaximumOptionsSize]byte
+ writePoint int
+}
+
+// MakeIterator sets up and returns an iterator of options. It also sets up the
+// building of a new option set.
+func (o IPv4Options) MakeIterator() IPv4OptionIterator {
+ return IPv4OptionIterator{
+ options: o,
+ nextErrCursor: IPv4MinimumSize,
+ }
+}
+
+// RemainingBuffer returns the remaining (unused) part of the new option buffer,
+// into which a new option may be written.
+func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options {
+ return IPv4Options(i.newOptions[i.writePoint:])
+}
+
+// ConsumeBuffer marks a portion of the new buffer as used.
+func (i *IPv4OptionIterator) ConsumeBuffer(size int) {
+ i.writePoint += size
+}
+
+// PushNOPOrEnd puts one of the single byte options onto the new options.
+// Only values 0 or 1 (ListEnd or NOP) are valid input.
+func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) {
+ if val > IPv4OptionNOPType {
+ panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val))
+ }
+ i.newOptions[i.writePoint] = byte(val)
+ i.writePoint++
+}
+
+// Finalize returns the completed replacement options buffer padded
+// as needed.
+func (i *IPv4OptionIterator) Finalize() IPv4Options {
+ // RFC 791 page 31 says:
+ // The options might not end on a 32-bit boundary. The internet header
+ // must be filled out with octets of zeros. The first of these would
+ // be interpreted as the end-of-options option, and the remainder as
+ // internet header padding.
+ // Since the buffer is already zero filled we just need to step the write
+ // pointer up to the next multiple of 4.
+ options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3])
+ // Poison the write pointer.
+ i.writePoint = len(i.newOptions)
+ return options
+}
+
+// Next returns the next IP option in the buffer/list of IP options.
+// It returns
+// - A slice of bytes holding the next option or nil if there is error.
+// - A boolean which is true if parsing of all the options is complete.
+// - An error which is non-nil if an error condition was encountered.
+func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) {
+ // The opts slice gets shorter as we process the options. When we have no
+ // bytes left we are done.
+ if len(i.options) == 0 {
+ return nil, true, nil
+ }
+
+ i.ErrCursor = i.nextErrCursor
+
+ optType := IPv4OptionType(i.options[ipv4OptionTypeOffset])
+
+ if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType {
+ optionBody := i.options[:1]
+ i.options = i.options[1:]
+ i.nextErrCursor = i.ErrCursor + 1
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+ }
+
+ // There are no more single byte options defined. All the rest have a length
+ // field so we need to sanity check it.
+ if len(i.options) == 1 {
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ optLen := i.options[IPv4OptionLengthOffset]
+
+ if optLen == 0 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptZeroLength
+ }
+
+ if optLen == 1 {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+
+ if optLen > uint8(len(i.options)) {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptionTruncated
+ }
+
+ optionBody := i.options[:optLen]
+ i.nextErrCursor = i.ErrCursor + optLen
+ i.options = i.options[optLen:]
+
+ // Check the length of some option types that we know.
+ switch optType {
+ case IPv4OptionTimestampType:
+ if optLen < IPv4OptionTimestampHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionTimestamp(optionBody)
+ return &retval, false, nil
+
+ case IPv4OptionRecordRouteType:
+ if optLen < IPv4OptionRecordRouteHdrLength {
+ i.ErrCursor++
+ return nil, true, ErrIPv4OptMalformed
+ }
+ retval := IPv4OptionRecordRoute(optionBody)
+ return &retval, false, nil
+ }
+ retval := IPv4OptionGeneric(optionBody)
+ return &retval, false, nil
+}
+
+//
+// IP Timestamp option - RFC 791 page 22.
+// +--------+--------+--------+--------+
+// |01000100| length | pointer|oflw|flg|
+// +--------+--------+--------+--------+
+// | internet address |
+// +--------+--------+--------+--------+
+// | timestamp |
+// +--------+--------+--------+--------+
+// | ... |
+//
+// Type = 68
+//
+// The Option Length is the number of octets in the option counting
+// the type, length, pointer, and overflow/flag octets (maximum
+// length 40).
+//
+// The Pointer is the number of octets from the beginning of this
+// option to the end of timestamps plus one (i.e., it points to the
+// octet beginning the space for next timestamp). The smallest
+// legal value is 5. The timestamp area is full when the pointer
+// is greater than the length.
+//
+// The Overflow (oflw) [4 bits] is the number of IP modules that
+// cannot register timestamps due to lack of space.
+//
+// The Flag (flg) [4 bits] values are
+//
+// 0 -- time stamps only, stored in consecutive 32-bit words,
+//
+// 1 -- each timestamp is preceded with internet address of the
+// registering entity,
+//
+// 3 -- the internet address fields are prespecified. An IP
+// module only registers its timestamp if it matches its own
+// address with the next specified internet address.
+//
+// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC.
+//
+// The Timestamp is a right-justified, 32-bit timestamp in
+// milliseconds since midnight UT. If the time is not available in
+// milliseconds or cannot be provided with respect to midnight UT
+// then any time may be inserted as a timestamp provided the high
+// order bit of the timestamp field is set to one to indicate the
+// use of a non-standard value.
+
+// IPv4OptTSFlags sefines the values expected in the Timestamp
+// option Flags field.
+type IPv4OptTSFlags uint8
+
+//
+// Timestamp option specific related constants.
+const (
+ // IPv4OptionTimestampHdrLength is the length of the timestamp option header.
+ IPv4OptionTimestampHdrLength = 4
+
+ // IPv4OptionTimestampSize is the size of an IP timestamp.
+ IPv4OptionTimestampSize = 4
+
+ // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address.
+ IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize
+
+ // IPv4OptionTimestampMaxSize is limited by space for options
+ IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize
+
+ // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp
+ // is present.
+ IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0
+
+ // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and
+ // IP are present.
+ IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1
+
+ // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that
+ // predefined IP is present.
+ IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3
+)
+
+// ipv4TimestampTime provides the current time as specified in RFC 791.
+func ipv4TimestampTime(clock tcpip.Clock) uint32 {
+ const millisecondsPerDay = 24 * 3600 * 1000
+ const nanoPerMilli = 1000000
+ return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+}
+
+// IP Timestamp option fields.
+const (
+ // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field.
+ IPv4OptTSPointerOffset = 2
+
+ // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow
+ // fields, (each being 4 bits).
+ IPv4OptTSOFLWAndFLGOffset = 3
+ // These constants define the sub byte fields of the Flag and OverFlow field.
+ ipv4OptionTimestampOverflowshift = 4
+ ipv4OptionTimestampFlagsMask byte = 0x0f
+)
+
+var _ IPv4Option = (*IPv4OptionTimestamp)(nil)
+
+// IPv4OptionTimestamp is a Timestamp option from RFC 791.
+type IPv4OptionTimestamp []byte
+
+// Type implements IPv4Option.Type().
+func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType }
+
+// Size implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
+
+// Contents implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+
+// Pointer returns the pointer field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Pointer() uint8 {
+ return (*ts)[IPv4OptTSPointerOffset]
+}
+
+// Flags returns the flags field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags {
+ return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask)
+}
+
+// Overflow returns the Overflow field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Overflow() uint8 {
+ return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift
+}
+
+// IncOverflow increments the Overflow field in the IP Timestamp option. It
+// returns the incremented value. If the return value is 0 then the field
+// overflowed.
+func (ts *IPv4OptionTimestamp) IncOverflow() uint8 {
+ (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift
+ return ts.Overflow()
+}
+
+// UpdateTimestamp updates the fields of the next free timestamp slot.
+func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) {
+ slot := (*ts)[ts.Pointer()-1:]
+
+ switch ts.Flags() {
+ case IPv4OptionTimestampOnlyFlag:
+ binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize
+ case IPv4OptionTimestampWithIPFlag:
+ if n := copy(slot, addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ case IPv4OptionTimestampWithPredefinedIPFlag:
+ if tcpip.Address(slot[:IPv4AddressSize]) == addr {
+ binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+ (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+ }
+ }
+}
+
+// RecordRoute option specific related constants.
+//
+// from RFC 791 page 20:
+// Record Route
+//
+// +--------+--------+--------+---------//--------+
+// |00000111| length | pointer| route data |
+// +--------+--------+--------+---------//--------+
+// Type=7
+//
+// The record route option provides a means to record the route of
+// an internet datagram.
+//
+// The option begins with the option type code. The second octet
+// is the option length which includes the option type code and the
+// length octet, the pointer octet, and length-3 octets of route
+// data. The third octet is the pointer into the route data
+// indicating the octet which begins the next area to store a route
+// address. The pointer is relative to this option, and the
+// smallest legal value for the pointer is 4.
+const (
+ // IPv4OptionRecordRouteHdrLength is the length of the Record Route option
+ // header.
+ IPv4OptionRecordRouteHdrLength = 3
+
+ // IPv4OptRRPointerOffset is the offset to the pointer field in an RR
+ // option, which points to the next free slot in the list of addresses.
+ IPv4OptRRPointerOffset = 2
+)
+
+var _ IPv4Option = (*IPv4OptionRecordRoute)(nil)
+
+// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791.
+type IPv4OptionRecordRoute []byte
+
+// Pointer returns the pointer field in the IP RecordRoute option.
+func (rr *IPv4OptionRecordRoute) Pointer() uint8 {
+ return (*rr)[IPv4OptRRPointerOffset]
+}
+
+// StoreAddress stores the given IPv4 address into the next free slot.
+func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) {
+ start := rr.Pointer() - 1 // A one based number.
+ // start and room checked by caller.
+ if n := copy((*rr)[start:], addr); n != IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+ }
+ (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize
+}
+
+// Type implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType }
+
+// Size implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) }
+
+// Contents implements IPv4Option.
+func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) }
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index c5d8a3456..09cb153b1 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -101,8 +101,10 @@ const (
// The address is ff02::2.
IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
- // section 5.
+ // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
+ // section 5:
+ // IPv6 requires that every link in the Internet have an MTU of 1280 octets
+ // or greater. This is known as the IPv6 minimum link MTU.
IPv6MinimumMTU = 1280
// IPv6Loopback is the IPv6 Loopback address.
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
new file mode 100644
index 000000000..ec92ed623
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ethernet",
+ srcs = ["ethernet.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
new file mode 100644
index 000000000..3eef7cd56
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ethernet provides an implementation of an ethernet link endpoint that
+// wraps an inner link endpoint.
+package ethernet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns an ethernet link endpoint that wraps an inner link endpoint.
+func New(ep stack.LinkEndpoint) *Endpoint {
+ var e Endpoint
+ e.Endpoint.Init(ep, &e)
+ return &e
+}
+
+// Endpoint is an ethernet endpoint.
+//
+// It adds an ethernet header to packets before sending them out through its
+// inner link endpoint and consumes an ethernet header before sending the
+// packet to the stack.
+type Endpoint struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ return
+ }
+
+ eth := header.Ethernet(hdr)
+ if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
+ }
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ return e.Endpoint.WritePacket(r, gso, proto, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ linkAddr := e.Endpoint.LinkAddress()
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ fields := header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: proto,
+ }
+ eth.Encode(&fields)
+}
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
index 76f563811..523b0d24b 100644
--- a/pkg/tcpip/link/pipe/pipe.go
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -26,27 +26,23 @@ import (
var _ stack.LinkEndpoint = (*Endpoint)(nil)
// New returns both ends of a new pipe.
-func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) {
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) {
ep1 := &Endpoint{
- linkAddr: linkAddr1,
- capabilities: capabilities,
+ linkAddr: linkAddr1,
}
ep2 := &Endpoint{
- linkAddr: linkAddr2,
- linked: ep1,
- capabilities: capabilities,
+ linkAddr: linkAddr2,
}
ep1.linked = ep2
+ ep2.linked = ep1
return ep1, ep2
}
// Endpoint is one end of a pipe.
type Endpoint struct {
- capabilities stack.LinkEndpointCapabilities
- linkAddr tcpip.LinkAddress
- dispatcher stack.NetworkDispatcher
- linked *Endpoint
- onWritePacket func(*stack.PacketBuffer)
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ linkAddr tcpip.LinkAddress
}
// WritePacket implements stack.LinkEndpoint.
@@ -55,16 +51,11 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network
return nil
}
- // The pipe endpoint will accept all multicast/broadcast link traffic and only
- // unicast traffic destined to itself.
- if len(e.linked.linkAddr) != 0 &&
- r.RemoteLinkAddress != e.linked.linkAddr &&
- r.RemoteLinkAddress != header.EthernetBroadcastAddress &&
- !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) {
- return nil
- }
-
- e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // Note that the local address from the perspective of this endpoint is the
+ // remote address from the perspective of the other end of the pipe
+ // (e.linked). Similarly, the remote address from the perspective of this
+ // endpoint is the local address on the other end.
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
}))
@@ -100,8 +91,8 @@ func (*Endpoint) MTU() uint32 {
}
// Capabilities implements stack.LinkEndpoint.
-func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.capabilities
+func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
}
// MaxHeaderLength implements stack.LinkEndpoint.
@@ -116,7 +107,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// ARPHardwareType implements stack.LinkEndpoint.
func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
- return header.ARPHardwareEther
+ return header.ARPHardwareNone
}
// AddHeader implements stack.LinkEndpoint.
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index dc239a0d0..2777f1411 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) {
const count = 1000000
var wg sync.WaitGroup
+ defer wg.Wait()
wg.Add(1)
go func() {
defer wg.Done()
@@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) {
}
}()
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
}
-
- rx.Flush()
}
- }()
- wg.Wait()
+ rx.Flush()
+ }
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 0243424f6..86f14db76 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "tun_endpoint_refs.go",
package = "tun",
prefix = "tunEndpoint",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "tunEndpoint",
},
@@ -28,6 +28,7 @@ go_library(
"//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index f94491026..cda6328a2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
// 2. Creating a new NIC.
id := tcpip.NICID(s.UniqueID())
- // TODO(gvisor.dev/1486): enable leak check for tunEndpoint.
endpoint := &tunEndpoint{
Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
stack: s,
@@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
name: name,
isTap: prefix == "tap",
}
+ endpoint.EnableLeakCheck()
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 59710352b..c118a2929 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -12,6 +12,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index b40dde96b..8a6bcfc2c 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -30,5 +30,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 7df77c66e..a79379abb 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -18,6 +18,7 @@
package arp
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -150,28 +151,36 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
remoteAddr := tcpip.Address(h.ProtocolAddressSender())
remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
+ e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol)
}
- // As per RFC 826, under Packet Reception:
- // Swap hardware and protocol fields, putting the local hardware and
- // protocol addresses in the sender fields.
- //
- // Send the packet to the (new) target hardware address on the same
- // hardware on which the request was received.
- origSender := h.HardwareAddressSender()
- r.RemoteLinkAddress = tcpip.LinkAddress(origSender)
respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize,
})
packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize))
packet.SetIPv4OverEthernet()
packet.SetOp(header.ARPReply)
- copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(packet.HardwareAddressTarget(), origSender)
- copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
- _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress())
+ if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ origSender := h.HardwareAddressSender()
+ if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize))
+ }
+ if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+
+ // As per RFC 826, under Packet Reception:
+ // Swap hardware and protocol fields, putting the local hardware and
+ // protocol addresses in the sender fields.
+ //
+ // Send the packet to the (new) target hardware address on the same
+ // hardware on which the request was received.
+ _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt)
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
@@ -199,6 +208,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
type protocol struct {
+ stack *stack.Stack
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
@@ -227,26 +237,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- r := &stack.Route{
- NetProto: ProtocolNumber,
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ if len(remoteLinkAddr) == 0 {
+ remoteLinkAddr = header.EthernetBroadcastAddress
}
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteLinkAddress = header.EthernetBroadcastAddress
+
+ nicID := nic.ID()
+ if len(localAddr) == 0 {
+ addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber)
+ if err != nil {
+ return err
+ }
+
+ if len(addr.Address) == 0 {
+ return tcpip.ErrNetworkUnreachable
+ }
+
+ localAddr = addr.Address
+ } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return tcpip.ErrBadLocalAddress
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize,
+ ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize,
})
h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ pkt.NetworkProtocolNumber = ProtocolNumber
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
- copy(h.HardwareAddressSender(), linkEP.LinkAddress())
- copy(h.ProtocolAddressSender(), localAddr)
- copy(h.ProtocolAddressTarget(), addr)
-
- return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
+ // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a
+ // link address.
+ _ = copy(h.HardwareAddressSender(), nic.LinkAddress())
+ if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize))
+ }
+ return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
@@ -286,6 +314,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
// Note, to make sure that the ARP endpoint receives ARP packets, the "arp"
// address must be added to every NIC that should respond to ARP requests. See
// ProtocolAddress for more details.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 626af975a..bf1292bb8 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -22,6 +22,7 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -78,13 +79,11 @@ func (t eventType) String() string {
type eventInfo struct {
eventType eventType
nicID tcpip.NICID
- addr tcpip.Address
- linkAddr tcpip.LinkAddress
- state stack.NeighborState
+ entry stack.NeighborEntry
}
func (e eventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry)
}
// arpDispatcher implements NUDDispatcher to validate the dispatching of
@@ -96,35 +95,29 @@ type arpDispatcher struct {
var _ stack.NUDDispatcher = (*arpDispatcher)(nil)
-func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryChanged,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
-func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) {
+func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) {
e := eventInfo{
eventType: entryRemoved,
nicID: nicID,
- addr: addr,
- linkAddr: linkAddr,
- state: state,
+ entry: entry,
}
d.C <- e
}
@@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address,
func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error {
select {
case got := <-d.C:
- if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" {
+ if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" {
return fmt.Errorf("got invalid event (-got +want):\n%s", diff)
}
case <-ctx.Done():
@@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
wantEvent := eventInfo{
eventType: entryAdded,
nicID: nicID,
- addr: test.senderAddr,
- linkAddr: tcpip.LinkAddress(test.senderLinkAddr),
- state: stack.Stale,
+ entry: stack.NeighborEntry{
+ Addr: test.senderAddr,
+ LinkAddr: tcpip.LinkAddress(test.senderLinkAddr),
+ State: stack.Stale,
+ },
}
if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil {
t.Fatal(err)
@@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want {
t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want)
}
- if got, want := neigh.LocalAddr, stackAddr; got != want {
- t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want)
- }
if got, want := neigh.State, stack.Stale; got != want {
t.Errorf("got neighbor State = %s, want = %s", got, want)
}
@@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) {
}
}
+var _ stack.NetworkInterface = (*testInterface)(nil)
+
+type testInterface struct {
+ stack.LinkEndpoint
+
+ nicID tcpip.NICID
+}
+
+func (t *testInterface) ID() tcpip.NICID {
+ return t.nicID
+}
+
+func (*testInterface) IsLoopback() bool {
+ return false
+}
+
+func (*testInterface) Name() string {
+ return ""
+}
+
+func (*testInterface) Enabled() bool {
+ return true
+}
+
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
+ testAddr := tcpip.Address([]byte{1, 2, 3, 4})
+
tests := []struct {
name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
remoteLinkAddr tcpip.LinkAddress
- expectLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedLocalAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
+ name: "Unicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast",
+ nicAddr: stackAddr,
+ localAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: remoteLinkAddr,
+ },
+ {
+ name: "Multicast with unspecified source",
+ nicAddr: stackAddr,
+ remoteLinkAddr: "",
+ expectedLocalAddr: stackAddr,
+ expectedRemoteLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: testAddr,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrBadLocalAddress,
+ },
+ {
+ name: "Unicast with no local address available",
remoteLinkAddr: remoteLinkAddr,
- expectLinkAddr: remoteLinkAddr,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
{
- name: "Multicast",
+ name: "Multicast with no local address available",
remoteLinkAddr: "",
- expectLinkAddr: header.EthernetBroadcastAddress,
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
for _, test := range tests {
- p := arp.NewProtocol(nil)
- linkRes, ok := p.(stack.LinkAddressResolver)
- if !ok {
- t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
- }
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
+ })
+ p := s.NetworkProtocolInstance(arp.ProtocolNumber)
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
- linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
- if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err)
- }
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
- pkt, ok := linkEP.Read()
- if !ok {
- t.Fatal("expected to send a link address request")
- }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err)
+ }
+ }
- if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
- }
+ // We pass a test network interface to LinkAddressRequest with the same
+ // NIC ID and link endpoint used by the NIC we created earlier so that we
+ // can mock a link address request and observe the packets sent to the
+ // link endpoint even though the stack uses the real NIC to validate the
+ // local address.
+ if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+ }
+
+ rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
+ t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+ t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+ }
+ if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
+ t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
+ }
+ })
}
}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index ed502a473..936601287 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
// proto is the protocol number marked in the fragment being processed. It has
// to be given here outside of the FragmentID struct because IPv6 should not use
// the protocol to identify a fragment.
+//
+// releaseCB is a callback that will run when the fragment reassembly of a
+// packet is complete or cancelled. releaseCB take a a boolean argument which is
+// true iff the reassembly is cancelled due to timeout. releaseCB should be
+// passed only with the first fragment of a packet. If more than one releaseCB
+// are passed for the same packet, only the first releaseCB will be saved for
+// the packet and the succeeding ones will be dropped by running them
+// immediately with a false argument.
func (f *Fragmentation) Process(
- id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+ id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
buffer.VectorisedView, uint8, bool, error) {
if first > last {
return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -171,6 +179,12 @@ func (f *Fragmentation) Process(
f.releaseReassemblersLocked()
}
}
+ if releaseCB != nil {
+ if !r.setCallback(releaseCB) {
+ // We got a duplicate callback. Release it immediately.
+ releaseCB(false /* timedOut */)
+ }
+ }
f.mu.Unlock()
res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
@@ -178,14 +192,14 @@ func (f *Fragmentation) Process(
// We probably got an invalid sequence of fragments. Just
// discard the reassembler and move on.
f.mu.Lock()
- f.release(r)
+ f.release(r, false /* timedOut */)
f.mu.Unlock()
return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
}
f.mu.Lock()
f.size += consumed
if done {
- f.release(r)
+ f.release(r, false /* timedOut */)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
@@ -195,14 +209,14 @@ func (f *Fragmentation) Process(
if tail == nil {
break
}
- f.release(tail)
+ f.release(tail, false /* timedOut */)
}
}
f.mu.Unlock()
return res, firstFragmentProto, done, nil
}
-func (f *Fragmentation) release(r *reassembler) {
+func (f *Fragmentation) release(r *reassembler, timedOut bool) {
// Before releasing a fragment we need to check if r is already marked as done.
// Otherwise, we would delete it twice.
if r.checkDoneOrMark() {
@@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) {
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
f.size = 0
}
+
+ r.release(timedOut) // releaseCB may run.
}
// releaseReassemblersLocked releases already-expired reassemblers, then
@@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() {
break
}
// If the oldest reassembler has already expired, release it.
- f.release(r)
+ f.release(r, true /* timedOut*/)
}
}
// PacketFragmenter is the book-keeping struct for packet fragmentation.
type PacketFragmenter struct {
- transportHeader buffer.View
- data buffer.VectorisedView
- reserve int
- innerMTU int
- fragmentCount int
- currentFragment int
- fragmentOffset int
+ transportHeader buffer.View
+ data buffer.VectorisedView
+ reserve int
+ fragmentPayloadLen int
+ fragmentCount int
+ currentFragment int
+ fragmentOffset int
}
// MakePacketFragmenter prepares the struct needed for packet fragmentation.
//
// pkt is the packet to be fragmented.
//
-// innerMTU is the maximum number of bytes of fragmentable data a fragment can
+// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
// have.
//
// reserve is the number of bytes that should be reserved for the headers in
// each generated fragment.
-func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter {
+func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
// As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
// repeated in each fragment. However we do not currently support any header
// of that kind yet, so the following computation is valid for both IPv4 and
@@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa
var fragmentableData buffer.VectorisedView
fragmentableData.AppendView(pkt.TransportHeader().View())
fragmentableData.Append(pkt.Data)
- fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU
+ fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
return PacketFragmenter{
- data: fragmentableData,
- reserve: reserve,
- innerMTU: innerMTU,
- fragmentCount: fragmentCount,
+ data: fragmentableData,
+ reserve: reserve,
+ fragmentPayloadLen: int(fragmentPayloadLen),
+ fragmentCount: int(fragmentCount),
}
}
@@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int,
})
// Copy data for the fragment.
- copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU)
+ copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen)
offset := pf.fragmentOffset
pf.fragmentOffset += copied
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index d3c7d7f92..5dcd10730 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) {
f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{})
firstFragmentProto := c.in[0].proto
for i, in := range c.in {
- vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv)
+ vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil)
if err != nil {
t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s",
in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err)
@@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) {
for _, event := range test.events {
clock.Advance(event.clockAdvance)
if frag := event.fragment; frag != nil {
- _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data))
+ _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil)
if err != nil {
t.Fatalf("%s: f.Process failed: %s", event.name, err)
}
@@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) {
func TestMemoryLimits(t *testing.T) {
f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send first fragment with id = 1.
- f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil)
// Send first fragment with id = 2.
- f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil)
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil)
if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
@@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) {
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{})
// Send first fragment with id = 0.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
// Send the same packet again.
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil)
got := f.size
want := 1
@@ -377,7 +377,7 @@ func TestErrors(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
- _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data))
+ _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil)
if !errors.Is(err, test.err) {
t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
}
@@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) {
tests := []struct {
name string
- innerMTU int
+ fragmentPayloadLen uint32
transportHeaderLen int
payloadSize int
wantFragments []fragmentInfo
}{
{
name: "Packet exactly fits in MTU",
- innerMTU: 1280,
+ fragmentPayloadLen: 1280,
transportHeaderLen: 0,
payloadSize: 1280,
wantFragments: []fragmentInfo{
@@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet exactly does not fit in MTU",
- innerMTU: 1000,
+ fragmentPayloadLen: 1000,
transportHeaderLen: 0,
payloadSize: 1001,
wantFragments: []fragmentInfo{
@@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a transport header",
- innerMTU: 560,
+ fragmentPayloadLen: 560,
transportHeaderLen: 40,
payloadSize: 560,
wantFragments: []fragmentInfo{
@@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) {
},
{
name: "Packet has a huge transport header",
- innerMTU: 500,
+ fragmentPayloadLen: 500,
transportHeaderLen: 1300,
payloadSize: 500,
wantFragments: []fragmentInfo{
@@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) {
originalPayload.AppendView(pkt.TransportHeader().View())
originalPayload.Append(pkt.Data)
var reassembledPayload buffer.VectorisedView
- pf := MakePacketFragmenter(pkt, test.innerMTU, reserve)
+ pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve)
for i := 0; ; i++ {
fragPkt, offset, copied, more := pf.BuildNextFragment()
wantFragment := test.wantFragments[i]
@@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) {
if more != wantFragment.more {
t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more)
}
- if got := fragPkt.Size(); got > test.innerMTU {
- t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU)
+ if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen {
+ t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen)
}
if got := fragPkt.AvailableHeaderBytes(); got != reserve {
t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve)
@@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) {
})
}
}
+
+func TestReleaseCallback(t *testing.T) {
+ const (
+ proto = 99
+ )
+
+ var result int
+ var callbackReasonIsTimeout bool
+ cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut }
+
+ tests := []struct {
+ name string
+ callbacks []func(bool)
+ timeout bool
+ wantResult int
+ wantCallbackReasonIsTimeout bool
+ }{
+ {
+ name: "callback runs on release",
+ callbacks: []func(bool){cb1},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "first callback is nil",
+ callbacks: []func(bool){nil, cb2},
+ timeout: false,
+ wantResult: 2,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "two callbacks - first one is set",
+ callbacks: []func(bool){cb1, cb2},
+ timeout: false,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: false,
+ },
+ {
+ name: "callback runs on timeout",
+ callbacks: []func(bool){cb1},
+ timeout: true,
+ wantResult: 1,
+ wantCallbackReasonIsTimeout: true,
+ },
+ {
+ name: "no callbacks",
+ callbacks: []func(bool){nil},
+ timeout: false,
+ wantResult: 0,
+ wantCallbackReasonIsTimeout: false,
+ },
+ }
+
+ id := FragmentID{ID: 0}
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ result = 0
+ callbackReasonIsTimeout = false
+
+ f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{})
+
+ for i, cb := range test.callbacks {
+ _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb)
+ if err != nil {
+ t.Errorf("f.Process error = %s", err)
+ }
+ }
+
+ r, ok := f.reassemblers[id]
+ if !ok {
+ t.Fatalf("Reassemberr not found")
+ }
+ f.release(r, test.timeout)
+
+ if result != test.wantResult {
+ t.Errorf("got result = %d, want = %d", result, test.wantResult)
+ }
+ if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout {
+ t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9bb051a30..c0cc0bde0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -41,6 +41,7 @@ type reassembler struct {
heap fragHeap
done bool
creationTime int64
+ callback func(bool)
}
func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler {
@@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool {
r.mu.Unlock()
return prev
}
+
+func (r *reassembler) setCallback(c func(bool)) bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.callback != nil {
+ return false
+ }
+ r.callback = c
+ return true
+}
+
+func (r *reassembler) release(timedOut bool) {
+ r.mu.Lock()
+ callback := r.callback
+ r.callback = nil
+ r.mu.Unlock()
+
+ if callback != nil {
+ callback(timedOut)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index a0a04a027..fa2a70dc8 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) {
}
}
}
+
+func TestSetCallback(t *testing.T) {
+ result := 0
+ reasonTimeout := false
+
+ cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut }
+ cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut }
+
+ r := newReassembler(FragmentID{}, &faketime.NullClock{})
+ if !r.setCallback(cb1) {
+ t.Errorf("setCallback failed")
+ }
+ if r.setCallback(cb2) {
+ t.Errorf("setCallback should fail if one is already set")
+ }
+ r.release(true)
+ if result != 1 {
+ t.Errorf("got result = %d, want = 1", result)
+ }
+ if !reasonTimeout {
+ t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout)
+ }
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d436873b6..969579601 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,11 +15,13 @@
package ip_test
import (
+ "strings"
"testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -302,6 +304,10 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestSourceAddressValidation(t *testing.T) {
rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
@@ -320,6 +326,7 @@ func TestSourceAddressValidation(t *testing.T) {
SrcAddr: src,
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -342,7 +349,6 @@ func TestSourceAddressValidation(t *testing.T) {
SrcAddr: src,
DstAddr: localIPv6Addr,
})
-
e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -579,6 +585,7 @@ func TestIPv4Receive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
@@ -660,6 +667,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
SrcAddr: "\x0a\x00\x00\xbb",
DstAddr: localIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Create the ICMP header.
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
@@ -679,12 +687,17 @@ func TestIPv4ReceiveControl(t *testing.T) {
SrcAddr: localIPv4Addr,
DstAddr: remoteIPv4Addr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
}
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
+
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
nic.testObject.protocol = 10
@@ -732,6 +745,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip1.SetChecksum(^ip1.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag1[i] = uint8(i)
@@ -748,6 +763,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: localIPv4Addr,
})
+ ip2.SetChecksum(^ip2.CalculateChecksum())
+
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag2[i] = uint8(i)
@@ -1020,3 +1037,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
_, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
+
+func TestWriteHeaderIncludedPacket(t *testing.T) {
+ const (
+ nicID = 1
+ transportProto = 5
+
+ dataLen = 4
+ optionsLen = 4
+ )
+
+ dataBuf := [dataLen]byte{1, 2, 3, 4}
+ data := dataBuf[:]
+
+ ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1}
+ ipv4Options := ipv4OptionsBuf[:]
+
+ ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
+ ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
+
+ var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
+ ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
+ if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ nicAddr tcpip.Address
+ remoteAddr tcpip.Address
+ pktGen func(*testing.T, tcpip.Address) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
+ expectedErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with IHL too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 1,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 minimum size",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(header.IPv4MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ totalLen := ipHdrLen + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(ipHdrLen))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHdrLen),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
+ }
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6 with extension header",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
+ checker.IPPayload(ipv6PayloadWithExtHdr),
+ )
+ },
+ },
+ {
+ name: "IPv6 minimum size",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(header.IPv6MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv6 too small",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ srcAddr tcpip.Address
+ }{
+ {
+ name: "unspecified source",
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ },
+ {
+ name: "random source",
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ e := channel.New(1, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
+
+ r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ }
+ defer r.Release()
+
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ })); err != test.expectedErr {
+ t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ test.checker(t, pkt.Pkt, subTest.srcAddr)
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 7fc12e229..6252614ec 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -29,6 +29,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3407755ed..cf287446e 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,6 +15,7 @@
package ipv4
import (
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -23,10 +24,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// handleControl handles the case when an ICMP packet contains the headers of
-// the original packet that caused the ICMP one to be sent. This information is
-// used to find out which transport endpoint must be notified about the ICMP
-// packet.
+// handleControl handles the case when an ICMP error packet contains the headers
+// of the original packet that caused the ICMP one to be sent. This information
+// is used to find out which transport endpoint must be notified about the ICMP
+// packet. We only expect the payload, not the enclosing ICMP packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
@@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
}
h := header.ICMPv4(v)
+ // Only do in-stack processing if the checksum is correct.
+ if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff {
+ received.Invalid.Increment()
+ // It's possible that a raw socket expects to receive this regardless
+ // of checksum errors. If it's an echo request we know it's safe because
+ // we are the only handler, however other types do not cope well with
+ // packets with checksum errors.
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ }
+ return
+ }
+
+ iph := header.IPv4(pkt.NetworkHeader().View())
+ var newOptions header.IPv4Options
+ if len(iph) > header.IPv4MinimumSize {
+ // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip
+ // type ICMP packets):
+ // If a Record Route and/or Time Stamp option is received in an
+ // ICMP Echo Request, this option (these options) SHOULD be
+ // updated to include the current host and included in the IP
+ // header of the Echo Reply message, without "truncation".
+ // Thus, the recorded route will be for the entire round trip.
+ //
+ // So we need to let the option processor know how it should handle them.
+ var op optionsUsage
+ if h.Type() == header.ICMPv4Echo {
+ op = &optionUsageEcho{}
+ } else {
+ op = &optionUsageReceive{}
+ }
+ aux, tmp, err := processIPOptions(r, iph.Options(), op)
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ newOptions = tmp
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv4Echo:
received.Echo.Increment()
- // Only send a reply if the checksum is valid.
- headerChecksum := h.Checksum()
- h.SetChecksum(0)
- calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
- h.SetChecksum(headerChecksum)
- if calculatedChecksum != headerChecksum {
- // It's possible that a raw socket still expects to receive this.
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
- received.Invalid.Increment()
+ sent := stats.ICMP.V4PacketsSent
+ if !r.Stack().AllowICMPMessage() {
+ sent.RateLimited.Increment()
return
}
@@ -98,9 +144,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// waiting endpoints. Consider moving responsibility for doing the copy to
// DeliverTransportPacket so that is is only done when needed.
replyData := pkt.Data.ToOwnedView()
- replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...))
+ // It's possible that a raw socket expects to receive this.
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+ pkt = nil
+ // Take the base of the incoming request IP header but replace the options.
+ replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions))
+ replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...))
+ replyIPHdr.SetHeaderLength(replyHeaderLength)
// As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
// source address MUST be one of its own IP addresses (but not a broadcast
@@ -139,7 +190,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
// The fields we need to alter.
//
// We need to produce the entire packet in the data segment in order to
- // use WriteHeaderIncludedPacket().
+ // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the
+ // total length and the header checksum so we don't need to set those here.
replyIPHdr.SetSourceAddress(r.LocalAddress)
replyIPHdr.SetDestinationAddress(r.RemoteAddress)
replyIPHdr.SetTTL(r.DefaultTTL())
@@ -157,8 +209,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
})
replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
- // The checksum will be calculated so we don't need to do it here.
- sent := stats.ICMP.V4PacketsSent
if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil {
sent.Dropped.Increment()
return
@@ -182,8 +232,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
}
case header.ICMPv4SrcQuench:
@@ -234,6 +287,21 @@ type icmpReasonProtoUnreachable struct{}
func (*icmpReasonProtoUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
+// icmpReasonParamProblem is an error to use to request a Parameter Problem
+// message to be sent.
+type icmpReasonParamProblem struct {
+ pointer byte
+}
+
+func (*icmpReasonParamProblem) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv4 and sends it back to the remote device that sent
// the problematic packet. It incorporates as much of that packet as
@@ -374,17 +442,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
- switch reason.(type) {
+ var counter *tcpip.StatCounter
+ switch reason := reason.(type) {
case *icmpReasonPortUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ counter = sent.DstUnreachable
case *icmpReasonProtoUnreachable:
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(header.ICMPv4ProtoUnreachable)
+ counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv4TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout)
+ counter = sent.TimeExceeded
+ case *icmpReasonParamProblem:
+ icmpHdr.SetType(header.ICMPv4ParamProblem)
+ icmpHdr.SetCode(header.ICMPv4UnusedCode)
+ icmpHdr.SetPointer(reason.pointer)
+ counter = sent.ParamProblem
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
- icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
- counter := sent.DstUnreachable
if err := route.WritePacket(
nil, /* gso */
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c5ac7b8b5..4592984a5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -16,7 +16,9 @@
package ipv4
import (
+ "errors"
"fmt"
+ "math"
"sync/atomic"
"time"
@@ -31,6 +33,8 @@ import (
)
const (
+ // ReassembleTimeout is the time a packet stays in the reassembly
+ // system before being evicted.
// As per RFC 791 section 3.2:
// The current recommendation for the initial timer setting is 15 seconds.
// This may be changed as experience with this protocol accumulates.
@@ -38,7 +42,7 @@ const (
// Considering that it is an old recommendation, we use the same reassembly
// timeout that linux defines, which is 30 seconds:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138
- reassembleTimeout = 30 * time.Second
+ ReassembleTimeout = 30 * time.Second
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
@@ -190,29 +198,6 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return e.protocol.Number()
}
-// writePacketFragments fragments pkt and writes the results on the link
-// endpoint. The IP header must already present in the original packet. The mtu
-// is the maximum size of the packets.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error {
- networkHeader := header.IPv4(pkt.NetworkHeader().View())
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader))
-
- for {
- fragPkt, more := buildNextFragment(&pf, networkHeader)
- if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1))
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- if !more {
- break
- }
- }
-
- return nil
-}
-
func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
length := uint16(pkt.Size())
@@ -234,10 +219,36 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
+// handleFragments fragments pkt and calls the handler function on each
+// fragment. It returns the number of fragments handled and the number of
+// fragments left to be processed. The IP header must already be present in the
+// original packet.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ // Round the MTU down to align to 8 bytes.
+ fragmentPayloadSize := networkMTU &^ 7
+ networkHeader := header.IPv4(pkt.NetworkHeader().View())
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader))
+
+ var n int
+ for {
+ fragPkt, more := buildNextFragment(&pf, networkHeader)
+ if err := handler(fragPkt); err != nil {
+ return n, pf.RemainingFragmentCount() + 1, err
+ }
+ n++
+ if !more {
+ return n, pf.RemainingFragmentCount(), nil
+ }
+ }
+}
+
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -273,9 +284,26 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
if r.Loop&stack.PacketOut == 0 {
return nil
}
- if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, e.nic.MTU(), pkt)
+
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
}
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
+ // fragment one by one using WritePacket() (current strategy) or if we
+ // want to create a PacketBufferList from the fragments and feed it to
+ // WritePackets(). It'll be faster but cost more memory.
+ return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt)
+ })
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(sent))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain))
+ return err
+ }
+
if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -293,9 +321,29 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
- for pkt := pkts.Front(); pkt != nil; {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
e.addIPHeader(r, pkt, params)
- pkt = pkt.Next()
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pkt
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ // Modify the packet list in place with the new fragments.
+ pkts.InsertAfter(pkt, fragPkt)
+ pkt = fragPkt
+ return nil
+ }); err != nil {
+ panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err))
+ }
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
+ }
}
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -347,30 +395,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacket writes a packet already containing a network
-// header through the given route.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrMalformedHeader
}
ip := header.IPv4(h)
- if !ip.IsValid(pkt.Data.Size()) {
- return tcpip.ErrInvalidOptionValue
- }
// Always set the total length.
- ip.SetTotalLength(uint16(pkt.Data.Size()))
+ pktSize := pkt.Data.Size()
+ ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ if ip.SourceAddress() == header.IPv4Any {
ip.SetSourceAddress(r.LocalAddress)
}
- // Set the destination. If the packet already included a destination,
- // it will be part of the route.
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
ip.SetDestinationAddress(r.RemoteAddress)
// Set the packet ID when zero.
@@ -387,19 +432,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if r.Loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, pkt.Clone())
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
}
- if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- return nil
+ return e.writePacket(r, nil /* gso */, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
@@ -415,6 +458,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ // There has been some confusion regarding verifying checksums. We need
+ // just look for negative 0 (0xffff) as the checksum, as it's not possible to
+ // get positive 0 (0) for the checksum. Some bad implementations could get it
+ // when doing entry replacement in the early days of the Internet,
+ // however the lore that one needs to check for both persists.
+ //
+ // RFC 1624 section 1 describes the source of this confusion as:
+ // [the partial recalculation method described in RFC 1071] computes a
+ // result for certain cases that differs from the one obtained from
+ // scratch (one's complement of one's complement sum of the original
+ // fields).
+ //
+ // However RFC 1624 section 5 clarifies that if using the verification method
+ // "recommended by RFC 1071, it does not matter if an intermediate system
+ // generated a -0 instead of +0".
+ //
+ // RFC1071 page 1 specifies the verification method as:
+ // (3) To check a checksum, the 1's complement sum is computed over the
+ // same set of octets, including the checksum field. If the result
+ // is all 1 bits (-0 in 1's complement arithmetic), the check
+ // succeeds.
+ if h.CalculateChecksum() != 0xffff {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+
// As per RFC 1122 section 3.2.1.3:
// When a host sends any datagram, the IP source address MUST
// be one of its own IP addresses (but not a broadcast or
@@ -455,6 +524,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
r.Stats().IP.MalformedFragmentsReceived.Increment()
return
}
+
+ // Set up a callback in case we need to send a Time Exceeded Message, as per
+ // RFC 792:
+ //
+ // If a host reassembling a fragmented datagram cannot complete the
+ // reassembly due to missing fragments within its time limit it discards
+ // the datagram, and it may send a time exceeded message.
+ //
+ // If fragment zero is not available then no time exceeded need be sent at
+ // all.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ r := r.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ }
+ r.Release()
+ }
+ }
+
var ready bool
var err error
proto := h.Protocol()
@@ -472,6 +563,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
h.More(),
proto,
pkt.Data,
+ releaseCB,
)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -481,9 +573,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
if !ready {
return
}
- }
+ // The reassembler doesn't take care of fixing up the header, so we need
+ // to do it here.
+ h.SetTotalLength(uint16(pkt.Data.Size() + len((h))))
+ h.SetFlagsFragmentOffset(0, 0)
+ }
r.Stats().IP.PacketsDelivered.Increment()
+
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
// TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport
@@ -493,6 +590,27 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
e.handleICMP(r, pkt)
return
}
+ if len(h.Options()) != 0 {
+ // TODO(gvisor.dev/issue/4586):
+ // When we add forwarding support we should use the verified options
+ // rather than just throwing them away.
+ aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{})
+ if err != nil {
+ switch {
+ case
+ errors.Is(err, header.ErrIPv4OptDuplicate),
+ errors.Is(err, errIPv4RecordRouteOptInvalidPointer),
+ errors.Is(err, errIPv4RecordRouteOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidLength),
+ errors.Is(err, errIPv4TimestampOptInvalidPointer),
+ errors.Is(err, errIPv4TimestampOptOverflow):
+ _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt)
+ e.protocol.stack.Stats().MalformedRcvdPackets.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ }
+ return
+ }
+ }
switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res {
case stack.TransportPacketHandled:
@@ -727,26 +845,32 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload mtu.
+func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv4MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
}
- return mtu - header.IPv4MinimumSize
-}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- if mtu > MaxTotalSize {
- mtu = MaxTotalSize
+ // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in
+ // length:
+ // The maximal internet header is 60 octets, and a typical internet header
+ // is 20 octets, allowing a margin for headers of higher level protocols.
+ if networkHeaderSize > header.IPv4MaximumHeaderSize {
+ return 0, tcpip.ErrMalformedHeader
}
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- return mtu
+
+ networkMTU := linkMTU
+ if networkMTU > MaxTotalSize {
+ networkMTU = MaxTotalSize
+ }
+
+ return networkMTU - uint32(networkHeaderSize), nil
+}
+
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// addressToUint32 translates an IPv4 address into its little endian uint32
@@ -785,7 +909,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
- fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
}
}
@@ -811,3 +935,322 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head
return fragPkt, more
}
+
+// optionAction describes possible actions that may be taken on an option
+// while processing it.
+type optionAction uint8
+
+const (
+ // optionRemove says that the option should not be in the output option set.
+ optionRemove optionAction = iota
+
+ // optionProcess says that the option should be fully processed.
+ optionProcess
+
+ // optionVerify says the option should be checked and passed unchanged.
+ optionVerify
+
+ // optionPass says to pass the output set without checking.
+ optionPass
+)
+
+// optionActions list what to do for each option in a given scenario.
+type optionActions struct {
+ // timestamp controls what to do with a Timestamp option.
+ timestamp optionAction
+
+ // recordroute controls what to do with a Record Route option.
+ recordRoute optionAction
+
+ // unknown controls what to do with an unknown option.
+ unknown optionAction
+}
+
+// optionsUsage specifies the ways options may be operated upon for a given
+// scenario during packet processing.
+type optionsUsage interface {
+ actions() optionActions
+}
+
+// optionUsageReceive implements optionsUsage for received packets.
+type optionUsageReceive struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageReceive) actions() optionActions {
+ return optionActions{
+ timestamp: optionVerify,
+ recordRoute: optionVerify,
+ unknown: optionPass,
+ }
+}
+
+// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it
+// is enabled (Process, Process, Pass) and for fragmenting (Process, Process,
+// Pass for frag1, but Remove,Remove,Remove for all other frags).
+
+// optionUsageEcho implements optionsUsage for echo packet processing.
+type optionUsageEcho struct{}
+
+// actions implements optionsUsage.
+func (*optionUsageEcho) actions() optionActions {
+ return optionActions{
+ timestamp: optionProcess,
+ recordRoute: optionProcess,
+ unknown: optionRemove,
+ }
+}
+
+var (
+ errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length")
+ errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer")
+ errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp")
+ errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags")
+)
+
+// handleTimestamp does any required processing on a Timestamp option
+// in place.
+func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) {
+ flags := tsOpt.Flags()
+ var entrySize uint8
+ switch flags {
+ case header.IPv4OptionTimestampOnlyFlag:
+ entrySize = header.IPv4OptionTimestampSize
+ case
+ header.IPv4OptionTimestampWithIPFlag,
+ header.IPv4OptionTimestampWithPredefinedIPFlag:
+ entrySize = header.IPv4OptionTimestampWithAddrSize
+ default:
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags
+ }
+
+ pointer := tsOpt.Pointer()
+ // To simplify processing below, base further work on the array of timestamps
+ // beyond the header, rather than on the whole option. Also to aid
+ // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based.
+ nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1)
+ optLen := tsOpt.Size()
+ dataLength := optLen - header.IPv4OptionTimestampHdrLength
+
+ // In the section below, we verify the pointer, length and overflow counter
+ // fields of the option. The distinction is in which byte you return as being
+ // in error in the ICMP packet. Offsets 1 (length), 2 pointer)
+ // or 3 (overflowed counter).
+ //
+ // The following RFC sections cover this section:
+ //
+ // RFC 791 (page 22):
+ // If there is some room but not enough room for a full timestamp
+ // to be inserted, or the overflow count itself overflows, the
+ // original datagram is considered to be in error and is discarded.
+ // In either case an ICMP parameter problem message may be sent to
+ // the source host [3].
+ //
+ // You can get this situation in two ways. Firstly if the data area is not
+ // a multiple of the entry size or secondly, if the pointer is not at a
+ // multiple of the entry size. The wording of the RFC suggests that
+ // this is not an error until you actually run out of space.
+ if pointer > optLen {
+ // RFC 791 (page 22) says we should switch to using the overflow count.
+ // If the timestamp data area is already full (the pointer exceeds
+ // the length) the datagram is forwarded without inserting the
+ // timestamp, but the overflow count is incremented by one.
+ if flags == header.IPv4OptionTimestampWithPredefinedIPFlag {
+ // By definition we have nothing to do.
+ return 0, nil
+ }
+
+ if tsOpt.IncOverflow() != 0 {
+ return 0, nil
+ }
+ // The overflow count is also full.
+ return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow
+ }
+ if nextSlot+entrySize > dataLength {
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad.
+ if false {
+ // We must select Pointer for Linux compatibility, even if
+ // only the length is bad.
+ // The Linux code is at (in October 2020)
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ // which doesn't distinguish between which of optptr[2] or optlen
+ // is wrong, but just arbitrarily decides on optptr+2.
+ if dataLength%entrySize != 0 {
+ // The Data section size should be a multiple of the expected
+ // timestamp entry size.
+ return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength
+ }
+ // If the size is OK, the pointer must be corrupted.
+ }
+ return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer
+ }
+
+ if usage.actions().timestamp == optionProcess {
+ tsOpt.UpdateTimestamp(localAddress, clock)
+ }
+ return 0, nil
+}
+
+var (
+ errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route")
+ errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route")
+)
+
+// handleRecordRoute checks and processes a Record route option. It is much
+// like the timestamp type 1 option, but without timestamps. The passed in
+// address is stored in the option in the correct spot if possible.
+func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) {
+ optlen := rrOpt.Size()
+
+ if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength {
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+
+ nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based.
+
+ // RFC 791 page 21 says
+ // If the route data area is already full (the pointer exceeds the
+ // length) the datagram is forwarded without inserting the address
+ // into the recorded route. If there is some room but not enough
+ // room for a full address to be inserted, the original datagram is
+ // considered to be in error and is discarded. In either case an
+ // ICMP parameter problem message may be sent to the source
+ // host.
+ // The use of the words "In either case" suggests that a 'full' RR option
+ // could generate an ICMP at every hop after it fills up. We chose to not
+ // do this (as do most implementations). It is probable that the inclusion
+ // of these words is a copy/paste error from the timestamp option where
+ // there are two failure reasons given.
+ if nextSlot >= optlen {
+ return 0, nil
+ }
+
+ // The data area isn't full but there isn't room for a new entry.
+ // Either Length or Pointer could be bad. We must select Pointer for Linux
+ // compatibility, even if only the length is bad.
+ if nextSlot+header.IPv4AddressSize > optlen {
+ if false {
+ // This is what we would do if we were not being Linux compatible.
+ // Check for bad pointer or length value. Must be a multiple of 4 after
+ // accounting for the 3 byte header and not within that header.
+ // RFC 791, page 20 says:
+ // The pointer is relative to this option, and the
+ // smallest legal value for the pointer is 4.
+ //
+ // A recorded route is composed of a series of internet addresses.
+ // Each internet address is 32 bits or 4 octets.
+ // Linux skips this test so we must too. See Linux code at:
+ // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341
+ // if (optptr[2]+3 > optlen) {
+ // pp_ptr = optptr + 2;
+ // goto error;
+ // }
+ if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 {
+ // Length is bad, not on integral number of slots.
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength
+ }
+ // If not length, the fault must be with the pointer.
+ }
+ return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer
+ }
+ if usage.actions().recordRoute == optionVerify {
+ return 0, nil
+ }
+ rrOpt.StoreAddress(localAddress)
+ return 0, nil
+}
+
+// processIPOptions parses the IPv4 options and produces a new set of options
+// suitable for use in the next step of packet processing as informed by usage.
+// The original will not be touched.
+//
+// Returns
+// - The location of an error if there was one (or 0 if no error)
+// - If there is an error, information as to what it was was.
+// - The replacement option set.
+func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) {
+
+ opts := header.IPv4Options(orig)
+ optIter := opts.MakeIterator()
+
+ // Each option other than NOP must only appear (RFC 791 section 3.1, at the
+ // definition of every type). Keep track of each of the possible types in
+ // the 8 bit 'type' field.
+ var seenOptions [math.MaxUint8 + 1]bool
+
+ // TODO(gvisor.dev/issue/4586):
+ // This will need tweaking when we start really forwarding packets
+ // as we may need to get two addresses, for rx and tx interfaces.
+ // We will also have to take usage into account.
+ prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber)
+ localAddress := prefixedAddress.Address
+ if err != nil {
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress
+ }
+ localAddress = r.LocalAddress
+ }
+
+ for {
+ option, done, err := optIter.Next()
+ if done || err != nil {
+ return optIter.ErrCursor, optIter.Finalize(), err
+ }
+ optType := option.Type()
+ if optType == header.IPv4OptionNOPType {
+ optIter.PushNOPOrEnd(optType)
+ continue
+ }
+ if optType == header.IPv4OptionListEndType {
+ optIter.PushNOPOrEnd(optType)
+ return 0 /* errCursor */, optIter.Finalize(), nil /* err */
+ }
+
+ // check for repeating options (multiple NOPs are OK)
+ if seenOptions[optType] {
+ return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate
+ }
+ seenOptions[optType] = true
+
+ optLen := int(option.Size())
+ switch option := option.(type) {
+ case *header.IPv4OptionTimestamp:
+ r.Stats().IP.OptionTSReceived.Increment()
+ if usage.actions().timestamp != optionRemove {
+ clock := r.Stack().Clock()
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ case *header.IPv4OptionRecordRoute:
+ r.Stats().IP.OptionRRReceived.Increment()
+ if usage.actions().recordRoute != optionRemove {
+ newBuffer := optIter.RemainingBuffer()[:len(*option)]
+ _ = copy(newBuffer, option.Contents())
+ offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage)
+ if err != nil {
+ return optIter.ErrCursor + offset, nil, err
+ }
+ optIter.ConsumeBuffer(optLen)
+ }
+
+ default:
+ r.Stats().IP.OptionUnknownReceived.Increment()
+ if usage.actions().unknown == optionPass {
+ newBuffer := optIter.RemainingBuffer()[:optLen]
+ // Arguments already heavily checked.. ignore result.
+ _ = copy(newBuffer, option.Contents())
+ optIter.ConsumeBuffer(optLen)
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 819b5d71f..61672a5ff 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -21,11 +21,13 @@ import (
"math"
"net"
"testing"
+ "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
@@ -39,13 +41,17 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+const (
+ extraHeaderReserve = 50
+ defaultMTU = 65536
+)
+
func TestExcludeBroadcast(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
ep = sniffer.New(ep)
@@ -101,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) {
// checks the response.
func TestIPv4Sanity(t *testing.T) {
const (
- defaultMTU = header.IPv6MinimumMTU
ttl = 255
nicID = 1
randomSequence = 123
@@ -116,23 +121,34 @@ func TestIPv4Sanity(t *testing.T) {
)
tests := []struct {
- name string
- headerLength uint8 // value of 0 means "use correct size"
- maxTotalLength uint16
- transportProtocol uint8
- TTL uint8
- shouldFail bool
- expectICMP bool
- ICMPType header.ICMPv4Type
- ICMPCode header.ICMPv4Code
- options []byte
+ name string
+ headerLength uint8 // value of 0 means "use correct size"
+ badHeaderChecksum bool
+ maxTotalLength uint16
+ transportProtocol uint8
+ TTL uint8
+ options []byte
+ replyOptions []byte // if succeeds, reply should look like this
+ shouldFail bool
+ expectErrorICMP bool
+ ICMPType header.ICMPv4Type
+ ICMPCode header.ICMPv4Code
+ paramProblemPointer uint8
}{
{
- name: "valid",
- maxTotalLength: defaultMTU,
+ name: "valid no options",
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
},
+ {
+ name: "bad header checksum",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ badHeaderChecksum: true,
+ shouldFail: true,
+ },
// The TTL tests check that we are not rejecting an incoming packet
// with a zero or one TTL, which has been a point of confusion in the
// past as RFC 791 says: "If this field contains the value zero, then the
@@ -146,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) {
// received with TTL less than 2.
{
name: "zero TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 0,
- shouldFail: false,
},
{
name: "one TTL",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: 1,
- shouldFail: false,
},
{
name: "End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{0, 0, 0, 0},
+ replyOptions: []byte{0, 0, 0, 0},
},
{
name: "NOP options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 1, 1},
+ replyOptions: []byte{1, 1, 1, 1},
},
{
name: "NOP and End options",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
options: []byte{1, 1, 0, 0},
+ replyOptions: []byte{1, 1, 0, 0},
},
{
name: "bad header length",
headerLength: header.IPv4MinimumSize - 1,
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (0)",
@@ -194,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip - 1)",
@@ -202,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad total length (ip + icmp - 1)",
@@ -210,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) {
transportProtocol: uint8(header.ICMPv4ProtocolNumber),
TTL: ttl,
shouldFail: true,
- expectICMP: false,
},
{
name: "bad protocol",
- maxTotalLength: defaultMTU,
+ maxTotalLength: ipv4.MaxTotalSize,
transportProtocol: 99,
TTL: ttl,
shouldFail: true,
- expectICMP: true,
+ expectErrorICMP: true,
ICMPType: header.ICMPv4DstUnreachable,
ICMPCode: header.ICMPv4ProtoUnreachable,
},
+ {
+ name: "timestamp option overflow",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ },
+ {
+ name: "timestamp option overflow full",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0xF1,
+ // ^ Counter full (15/0xF)
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 3,
+ replyOptions: []byte{},
+ },
+ {
+ name: "unknown option",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{10, 4, 9, 0},
+ // ^^
+ // The unknown option should be stripped out of the reply.
+ replyOptions: []byte{},
+ },
+ {
+ name: "bad option - length 0",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 0, 9, 0,
+ // ^
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ name: "bad option - length big",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 9, 9, 0,
+ // ^
+ // There are only 8 bytes allocated to options so 9 bytes of timestamp
+ // space is not possible. (Second byte)
+ 1, 2, 3, 4,
+ },
+ shouldFail: true,
+ },
+ {
+ // This tests for some linux compatible behaviour.
+ // The ICMP pointer returned is 22 for Linux but the
+ // error is actually in spot 21.
+ name: "bad option - length bad",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ // Timestamps are in multiples of 4 or 8 but never 7.
+ // The option space should be padded out.
+ options: []byte{
+ 68, 7, 5, 0,
+ // ^ ^ Linux points here which is wrong.
+ // | Not a multiple of 4
+ 1, 2, 3,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ {
+ name: "multiple type 0 with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 24, 21, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 24, 25, 0x00,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // The timestamp area is full so add to the overflow count.
+ name: "multiple type 1 timestamps",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 21, 0x11,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ // Overflow count is the top nibble of the 4th byte.
+ replyOptions: []byte{
+ 68, 20, 21, 0x21,
+ // ^
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ },
+ },
+ {
+ name: "multiple type 1 timestamps with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 28, 21, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ replyOptions: []byte{
+ 68, 28, 29, 0x01,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 192, 168, 1, 13,
+ 5, 6, 7, 8,
+ 192, 168, 1, 58, // New IP Address.
+ 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock
+ },
+ },
+ {
+ // Needs 8 bytes for a type 1 timestamp but there are only 4 free.
+ name: "bad timer element alignment",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 20, 17, 0x01,
+ // ^^ ^^ 20 byte area, next free spot at 17.
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ },
+ // End of option list with illegal option after it, which should be ignored.
+ {
+ name: "end of options list",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 68, 12, 13, 0x11,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 10, 3, 99,
+ },
+ replyOptions: []byte{
+ 68, 12, 13, 0x21,
+ 192, 168, 1, 12,
+ 1, 2, 3, 4,
+ 0, 0, 0, 0, // 3 bytes unknown option
+ }, // ^ End of options hides following bytes.
+ },
+ {
+ // Timestamp with a size too small.
+ name: "timestamp truncated",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{68, 1, 0, 0},
+ // ^ Smallest possible is 8.
+ shouldFail: true,
+ },
+ {
+ name: "single record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 4, // 3 byte header
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "multiple record route with room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 20, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 0, 0, 0, 0,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 192, 168, 1, 58, // New IP Address.
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ name: "single record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Unlike timestamp, this should just succeed.
+ name: "multiple record route with no room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 23, 24, // 3 byte header
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0,
+ },
+ replyOptions: []byte{
+ 7, 23, 24,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 0, // padding to multiple of 4 bytes.
+ },
+ },
+ {
+ // Confirm linux bug for bug compatibility.
+ // Linux returns slot 22 but the error is in slot 21.
+ name: "multiple record route with not enough room",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 8, 8, // 3 byte header
+ // ^ ^ Linux points here. We must too.
+ // | Not enough room. 1 byte free, need 4.
+ 1, 2, 3, 4,
+ 0,
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 2,
+ replyOptions: []byte{},
+ },
+ {
+ name: "duplicate record route",
+ maxTotalLength: ipv4.MaxTotalSize,
+ transportProtocol: uint8(header.ICMPv4ProtocolNumber),
+ TTL: ttl,
+ options: []byte{
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 7, 7, 8, // 3 byte header
+ 1, 2, 3, 4,
+ 0, 0, // pad
+ },
+ shouldFail: true,
+ expectErrorICMP: true,
+ ICMPType: header.ICMPv4ParamProblem,
+ ICMPCode: header.ICMPv4UnusedCode,
+ paramProblemPointer: header.IPv4MinimumSize + 7,
+ replyOptions: []byte{},
+ },
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ Clock: clock,
})
// We expect at most a single packet in response to our ICMP Echo Request.
- e := channel.New(1, defaultMTU, "")
+ e := channel.New(1, ipv4.MaxTotalSize, "")
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
@@ -239,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) {
if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err)
}
+ // Advance the clock by some unimportant amount to make
+ // sure it's all set up.
+ clock.Advance(time.Millisecond * 0x10203040)
// Default routes for IPv4 so ICMP can find a route to the remote
// node when attempting to send the ICMP Echo Reply.
@@ -288,6 +638,12 @@ func TestIPv4Sanity(t *testing.T) {
if test.headerLength != 0 {
ip.SetHeaderLength(test.headerLength)
}
+ ip.SetChecksum(0)
+ ipHeaderChecksum := ip.CalculateChecksum()
+ if test.badHeaderChecksum {
+ ipHeaderChecksum += 42
+ }
+ ip.SetChecksum(^ipHeaderChecksum)
requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
})
@@ -295,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) {
reply, ok := e.Read()
if !ok {
if test.shouldFail {
- if test.expectICMP {
- t.Fatal("expected ICMP error response missing")
+ if test.expectErrorICMP {
+ t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode)
}
return // Expected silent failure.
}
t.Fatal("expected ICMP echo reply missing")
}
+ // We didn't expect a packet. Register our surprise but carry on to
+ // provide more information about what we got.
+ if test.shouldFail && !test.expectErrorICMP {
+ t.Error("unexpected packet response")
+ }
+
// Check the route that brought the packet to us.
if reply.Route.LocalAddress != ipv4Addr.Address {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address)
@@ -311,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) {
t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr)
}
- // Make sure it's all in one buffer.
- vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views())
- replyIPHeader := header.IPv4(vv.ToView())
+ // Make sure it's all in one buffer for checker.
+ replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader()))
- // At this stage we only know it's an IP header so verify that much.
+ // At this stage we only know it's probably an IP+ICMP header so verify
+ // that much.
checker.IPv4(t, replyIPHeader,
checker.SrcAddr(ipv4Addr.Address),
checker.DstAddr(remoteIPv4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Checksum(),
+ ),
)
- // All expected responses are ICMP packets.
- if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want {
- t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want)
+ // Don't proceed any further if the checker found problems.
+ if t.Failed() {
+ t.FailNow()
}
- replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
- // Sanity check the response.
+ // OK it's ICMP. We can safely look at the type now.
+ replyICMPHeader := header.ICMPv4(replyIPHeader.Payload())
switch replyICMPHeader.Type() {
- case header.ICMPv4DstUnreachable:
+ case header.ICMPv4ParamProblem:
+ if !test.shouldFail {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer())
+ }
checker.IPv4(t, replyIPHeader,
checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
checker.IPv4HeaderLength(header.IPv4MinimumSize),
checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
checker.ICMPv4Code(test.ICMPCode),
- checker.ICMPv4Checksum(),
+ checker.ICMPv4Pointer(test.paramProblemPointer),
checker.ICMPv4Payload([]byte(hdr.View())),
),
)
- if !test.shouldFail || !test.expectICMP {
- t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d",
+ return
+ case header.ICMPv4DstUnreachable:
+ if !test.shouldFail {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply",
+ header.ICMPv4DstUnreachable, replyICMPHeader.Code())
+ }
+ if !test.expectErrorICMP {
+ t.Fatalf("got ICMP error packet type %d, code %d, wanted no response",
header.ICMPv4DstUnreachable, replyICMPHeader.Code())
}
+ checker.IPv4(t, replyIPHeader,
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(test.ICMPType),
+ checker.ICMPv4Code(test.ICMPCode),
+ checker.ICMPv4Payload([]byte(hdr.View())),
+ ),
+ )
return
case header.ICMPv4EchoReply:
+ if test.shouldFail {
+ if !test.expectErrorICMP {
+ t.Error("got Echo Reply packet, want no response")
+ } else {
+ t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode)
+ }
+ }
+ // If the IP options change size then the packet will change size, so
+ // some IP header fields will need to be adjusted for the checks.
+ sizeChange := len(test.replyOptions) - len(test.options)
+
checker.IPv4(t, replyIPHeader,
- checker.IPv4HeaderLength(ipHeaderLength),
- checker.IPv4Options(test.options),
- checker.IPFullLength(uint16(requestPkt.Size())),
+ checker.IPv4HeaderLength(ipHeaderLength+sizeChange),
+ checker.IPv4Options(test.replyOptions),
+ checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)),
checker.ICMPv4(
+ checker.ICMPv4Checksum(),
checker.ICMPv4Code(header.ICMPv4UnusedCode),
checker.ICMPv4Seq(randomSequence),
checker.ICMPv4Ident(randomIdent),
- checker.ICMPv4Checksum(),
),
)
- if test.shouldFail {
- t.Fatalf("unexpected Echo Reply packet\n")
- }
default:
- t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d",
- replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable)
+ t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d",
+ replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem)
}
})
}
@@ -369,7 +764,7 @@ func TestIPv4Sanity(t *testing.T) {
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32) error {
+func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error {
// Make a complete array of the sourcePacket packet.
source := header.IPv4(packets[0].NetworkHeader().View())
vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views())
@@ -381,7 +776,6 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
sourceCopy.SetChecksum(0)
sourceCopy.SetFlagsFragmentOffset(0, 0)
sourceCopy.SetTotalLength(0)
- var offset uint16
// Build up an array of the bytes sent.
var reassembledPayload buffer.VectorisedView
for i, packet := range packets {
@@ -391,35 +785,38 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) {
return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader))
}
- if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
- return fmt.Errorf("fragment #%d: fragmentIPHeader.CalculateChecksum() got %#x, want %#x", i, got, want)
- }
if got := len(fragmentIPHeader); got > int(mtu) {
return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu)
}
- if got, want := packet.AvailableHeaderBytes(), sourcePacket.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
- return fmt.Errorf("fragment #%d: should have the same available space for prepending as source: got %d, want %d", i, got, want)
+ if got := fragmentIPHeader.TransportProtocol(); got != proto {
+ return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto))
+ }
+ if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
}
if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want {
- return fmt.Errorf("fragment #%d: has wrong network protocol number: got %d, want %d", i, got, want)
+ return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want)
+ }
+ if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want {
+ return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want)
}
- if i < len(packets)-1 {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
+ if wantFragments[i].more {
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset)
} else {
- sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset)
+ sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset)
}
reassembledPayload.AppendView(packet.TransportHeader().View())
reassembledPayload.Append(packet.Data)
- offset += fragmentIPHeader.TotalLength() - uint16(fragmentIPHeader.HeaderLength())
// Clear out the checksum and length from the ip because we can't compare
// it.
- sourceCopy.SetTotalLength(uint16(len(fragmentIPHeader)))
+ sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize)
sourceCopy.SetChecksum(0)
sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum())
if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" {
- return fmt.Errorf("fragment #%d: fragmentIPHeader[:fragmentIPHeader.HeaderLength()] mismatch (-want +got):\n%s", i, diff)
+ return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
}
+
expected := buffer.View(source[source.HeaderLength():])
if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" {
return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
@@ -428,38 +825,122 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
return nil
}
-func TestFragmentation(t *testing.T) {
- const ttl = 42
+type fragmentInfo struct {
+ offset uint16
+ more bool
+ payloadSize uint16
+}
- var manyPayloadViewsSizes [1000]int
- for i := range manyPayloadViewsSizes {
- manyPayloadViewsSizes[i] = 7
- }
- fragTests := []struct {
- description string
- mtu uint32
- gso *stack.GSO
- transportHeaderLength int
- extraHeaderReserveLength int
- payloadViewsSizes []int
- expectedFrags int
- }{
- {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1},
- {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1},
- {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25},
- {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25},
- {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2},
- {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2},
- {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6},
- }
+var fragmentationTests = []struct {
+ description string
+ mtu uint32
+ gso *stack.GSO
+ transportHeaderLength int
+ payloadSize int
+ wantFragments []fragmentInfo
+}{
+ {
+ description: "No fragmentation",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1000, more: false},
+ },
+ },
+ {
+ description: "Fragmented",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 744, more: false},
+ },
+ },
+ {
+ description: "Fragmented with the minimum mtu",
+ mtu: header.IPv4MinimumMTU,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv4MinimumMTU + 1,
+ gso: nil,
+ transportHeaderLength: 0,
+ payloadSize: 100,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 48, more: true},
+ {offset: 48, payloadSize: 48, more: true},
+ {offset: 96, payloadSize: 4, more: false},
+ },
+ },
+ {
+ description: "No fragmentation with big header",
+ mtu: 2000,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1100, more: false},
+ },
+ },
+ {
+ description: "Fragmented with gso none",
+ mtu: 1280,
+ gso: &stack.GSO{Type: stack.GSONone},
+ transportHeaderLength: 0,
+ payloadSize: 1400,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 144, more: false},
+ },
+ },
+ {
+ description: "Fragmented with big header",
+ mtu: 1280,
+ gso: nil,
+ transportHeaderLength: 100,
+ payloadSize: 1200,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1256, more: true},
+ {offset: 1256, payloadSize: 44, more: false},
+ },
+ },
+ {
+ description: "Fragmented with MTU smaller than header",
+ mtu: 300,
+ gso: nil,
+ transportHeaderLength: 1000,
+ payloadSize: 500,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 280, more: true},
+ {offset: 280, payloadSize: 280, more: true},
+ {offset: 560, payloadSize: 280, more: true},
+ {offset: 840, payloadSize: 280, more: true},
+ {offset: 1120, payloadSize: 280, more: true},
+ {offset: 1400, payloadSize: 100, more: false},
+ },
+ },
+}
+
+func TestFragmentationWritePacket(t *testing.T) {
+ const ttl = 42
- for _, ft := range fragTests {
+ for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
source := pkt.Clone()
err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
@@ -469,60 +950,149 @@ func TestFragmentation(t *testing.T) {
if err != nil {
t.Fatalf("r.WritePacket(_, _, _) = %s", err)
}
-
- if got := len(ep.WrittenPackets); got != ft.expectedFrags {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags)
+ if got := len(ep.WrittenPackets); got != len(ft.wantFragments) {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments))
}
- if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want {
- t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments))
}
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu); err != nil {
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
t.Error(err)
}
})
}
}
-// TestFragmentationErrors checks that errors are returned from write packet
+func TestFragmentationWritePackets(t *testing.T) {
+ const ttl = 42
+ writePacketsTests := []struct {
+ description string
+ insertBefore int
+ insertAfter int
+ }{
+ {
+ description: "Single packet",
+ insertBefore: 0,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet before",
+ insertBefore: 1,
+ insertAfter: 0,
+ },
+ {
+ description: "With packet after",
+ insertBefore: 0,
+ insertAfter: 1,
+ },
+ {
+ description: "With packet before and after",
+ insertBefore: 1,
+ insertAfter: 1,
+ },
+ }
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber)
+
+ for _, test := range writePacketsTests {
+ t.Run(test.description, func(t *testing.T) {
+ for _, ft := range fragmentationTests {
+ t.Run(ft.description, func(t *testing.T) {
+ var pkts stack.PacketBufferList
+ for i := 0; i < test.insertBefore; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ pkts.PushBack(pkt.Clone())
+ for i := 0; i < test.insertAfter; i++ {
+ pkts.PushBack(tinyPacket.Clone())
+ }
+
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
+ r := buildRoute(t, ep)
+
+ wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter
+ n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: ttl,
+ TOS: stack.DefaultTOS,
+ })
+ if err != nil {
+ t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err)
+ }
+ if n != wantTotalPackets {
+ t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets)
+ }
+ if got := len(ep.WrittenPackets); got != wantTotalPackets {
+ t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets {
+ t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets)
+ }
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
+ }
+
+ if wantTotalPackets == 0 {
+ return
+ }
+
+ fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore]
+ if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+ })
+ }
+}
+
+// TestFragmentationErrors checks that errors are returned from WritePacket
// correctly.
func TestFragmentationErrors(t *testing.T) {
const ttl = 42
- expectedError := tcpip.ErrAborted
- fragTests := []struct {
+ tests := []struct {
description string
mtu uint32
transportHeaderLength int
payloadSize int
allowPackets int
- fragmentCount int
+ outgoingErrors int
+ mockError *tcpip.Error
+ wantError *tcpip.Error
}{
{
description: "No frag",
mtu: 2000,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 1,
+ outgoingErrors: 1,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 0,
- fragmentCount: 3,
+ outgoingErrors: 3,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on second frag",
mtu: 500,
- transportHeaderLength: 0,
payloadSize: 1000,
+ transportHeaderLength: 0,
allowPackets: 1,
- fragmentCount: 3,
+ outgoingErrors: 2,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
},
{
description: "Error on first frag MTU smaller than header",
@@ -530,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) {
transportHeaderLength: 1000,
payloadSize: 500,
allowPackets: 0,
- fragmentCount: 4,
+ outgoingErrors: 4,
+ mockError: tcpip.ErrAborted,
+ wantError: tcpip.ErrAborted,
+ },
+ {
+ description: "Error when MTU is smaller than IPv4 minimum MTU",
+ mtu: header.IPv4MinimumMTU - 1,
+ transportHeaderLength: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
},
}
- for _, ft := range fragTests {
+ for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets)
+ pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
+ ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
- pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
Protocol: tcp.ProtocolNumber,
TTL: ttl,
TOS: stack.DefaultTOS,
}, pkt)
- if err != expectedError {
- t.Errorf("got WritePacket() = %s, want = %s", err, expectedError)
+ if err != ft.wantError {
+ t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError)
}
- if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want {
- t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets {
+ t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets)
}
- if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want {
- t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want)
+ if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors {
+ t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors)
}
})
}
@@ -583,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) {
autoChecksum bool // if true, the Checksum field will be overwritten.
}
- // These packets have both IHL and TotalLength set to 0.
tests := []struct {
name string
fragments []fragmentData
@@ -823,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
-
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -866,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) {
}
}
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ nicID = 1
+ linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ addr1 = "\x0a\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x02"
+ tos = 0
+ ident = 1
+ ttl = 48
+ protocol = 99
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv4fields header.IPv4Fields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 16,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 8,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16),
+ ID: ident,
+ Flags: 0,
+ FragmentOffset: 16,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv4fields: header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: tos,
+ TotalLength: header.IPv4MinimumSize + 8,
+ ID: ident,
+ Flags: header.IPv4FlagMoreFragments,
+ FragmentOffset: 0,
+ TTL: ttl,
+ Protocol: protocol,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ payload: []byte(data)[:8],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ ipv4.NewProtocol,
+ },
+ Clock: clock,
+ })
+ e := channel.New(1, 1500, linkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ pktSize := header.IPv4MinimumSize
+ hdr := buffer.NewPrependable(pktSize)
+
+ ip := header.IPv4(hdr.Prepend(pktSize))
+ ip.Encode(&f.ipv4fields)
+
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && ip.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(header.IPv4ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ipv4.ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4TimeExceeded),
+ checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload([]byte(firstFragmentSent)),
+ ),
+ )
+ })
+ }
+}
+
// TestReceiveFragments feeds fragments in through the incoming packet path to
// test reassembly
func TestReceiveFragments(t *testing.T) {
@@ -1281,6 +2114,7 @@ func TestReceiveFragments(t *testing.T) {
SrcAddr: frag.srcAddr,
DstAddr: frag.dstAddr,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
vv := hdr.View().ToVectorisedView()
vv.AppendView(frag.payload)
@@ -1415,7 +2249,7 @@ func TestWriteStats(t *testing.T) {
t.Run(writer.name, func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets)
+ ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets)
rt := buildRoute(t, ep)
var pkts stack.PacketBufferList
@@ -1549,6 +2383,7 @@ func TestPacketQueing(t *testing.T) {
SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -1592,6 +2427,7 @@ func TestPacketQueing(t *testing.T) {
SrcAddr: host2IPv4Addr.AddressWithPrefix.Address,
DstAddr: host1IPv4Addr.AddressWithPrefix.Address,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
}))
@@ -1619,7 +2455,7 @@ func TestPacketQueing(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr)
+ e := channel.New(1, defaultMTU, host1NICLinkAddr)
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol},
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index a30437f02..0ac24a6fb 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -36,6 +36,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index ead6bedcb..3c15e41a7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -170,8 +170,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
return
}
pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := header.ICMPv6(hdr).MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
+ networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ networkMTU = 0
+ }
+ e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
@@ -284,7 +287,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
received.Invalid.Increment()
return
} else if e.nud != nil {
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
@@ -555,7 +558,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
if e.nud != nil {
// A RS with a specified source IP address modifies the NUD state
// machine in the same way a reachability probe would.
- e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
}
}
@@ -605,7 +608,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// If the RA has the source link layer option, update the link address
// cache with the link address for the advertised router.
if len(sourceLinkAddr) != 0 && e.nud != nil {
- e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
+ e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
}
e.mu.Lock()
@@ -648,52 +651,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
- // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
- // route here. Note, we would need the nicID to do this properly so the right
- // NIC (associated to linkEP) is used to send the NDP NS message.
- r := stack.Route{
- LocalAddress: localAddr,
- RemoteAddress: addr,
- LocalLinkAddress: linkEP.LinkAddress(),
- RemoteLinkAddress: remoteLinkAddr,
+func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error {
+ remoteAddr := targetAddr
+ if len(remoteLinkAddr) == 0 {
+ remoteAddr = header.SolicitedNodeAddr(targetAddr)
+ remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr)
}
- // If a remote address is not already known, then send a multicast
- // solicitation since multicast addresses have a static mapping to link
- // addresses.
- if len(r.RemoteLinkAddress) == 0 {
- r.RemoteAddress = header.SolicitedNodeAddr(addr)
- r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress)
+ r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ return err
}
+ defer r.Release()
+ r.ResolveWith(remoteLinkAddr)
optsSerializer := header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()),
+ header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()),
}
neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length()
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
- ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize,
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize,
})
pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize))
packet.SetType(header.ICMPv6NeighborSolicit)
ns := header.NDPNeighborSolicit(packet.NDPPayload())
- ns.SetTargetAddress(addr)
+ ns.SetTargetAddress(targetAddr)
ns.Options().Serialize(optsSerializer)
packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- length := uint16(pkt.Size())
- ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: length,
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: header.NDPHopLimit,
- SrcAddr: r.LocalAddress,
- DstAddr: r.RemoteAddress,
- })
+ stat := p.stack.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ }, pkt); err != nil {
+ stat.Dropped.Increment()
+ return err
+ }
- // TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt)
+ stat.NeighborSolicit.Increment()
+ return nil
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
@@ -747,6 +744,13 @@ type icmpReasonPortUnreachable struct{}
func (*icmpReasonPortUnreachable) isICMPReason() {}
+// icmpReasonReassemblyTimeout is an error where insufficient fragments are
+// received to complete reassembly of a packet within a configured time after
+// the reception of the first-arriving fragment of that packet.
+type icmpReasonReassemblyTimeout struct{}
+
+func (*icmpReasonReassemblyTimeout) isICMPReason() {}
+
// returnError takes an error descriptor and generates the appropriate ICMP
// error packet for IPv6 and sends it.
func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error {
@@ -839,7 +843,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+ payload := network.ToVectorisedView()
+ payload.AppendView(transport)
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -860,6 +866,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(header.ICMPv6PortUnreachable)
counter = sent.DstUnreachable
+ case *icmpReasonReassemblyTimeout:
+ icmpHdr.SetType(header.ICMPv6TimeExceeded)
+ icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout)
+ counter = sent.TimeExceeded
default:
panic(fmt.Sprintf("unsupported ICMP type %T", reason))
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 8dc33c560..aa8b5f2e5 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -51,6 +51,7 @@ const (
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
+ lladdr2 = header.LinkLocalAddr(linkAddr2)
)
type stubLinkEndpoint struct {
@@ -108,31 +109,27 @@ type stubNUDHandler struct {
var _ stack.NUDHandler = (*stubNUDHandler)(nil)
-func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) {
+func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) {
s.probeCount++
}
-func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) {
+func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) {
s.confirmationCount++
}
-func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) {
+func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) {
}
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
- stack.NetworkLinkEndpoint
-
- linkAddr tcpip.LinkAddress
-}
+ stack.LinkEndpoint
-func (i *testInterface) LinkAddress() tcpip.LinkAddress {
- return i.linkAddr
+ nicID tcpip.NICID
}
func (*testInterface) ID() tcpip.NICID {
- return 0
+ return nicID
}
func (*testInterface) IsLoopback() bool {
@@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool {
return true
}
+func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ r := stack.Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt)
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -1235,26 +1240,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
}
func TestLinkAddressRequest(t *testing.T) {
+ const nicID = 1
+
snaddr := header.SolicitedNodeAddr(lladdr0)
mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
tests := []struct {
- name string
- remoteLinkAddr tcpip.LinkAddress
- expectedLinkAddr tcpip.LinkAddress
- expectedAddr tcpip.Address
+ name string
+ nicAddr tcpip.Address
+ localAddr tcpip.Address
+ remoteLinkAddr tcpip.LinkAddress
+
+ expectedErr *tcpip.Error
+ expectedRemoteAddr tcpip.Address
+ expectedRemoteLinkAddr tcpip.LinkAddress
}{
{
- name: "Unicast",
- remoteLinkAddr: linkAddr1,
- expectedLinkAddr: linkAddr1,
- expectedAddr: lladdr0,
+ name: "Unicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ nicAddr: lladdr1,
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedRemoteAddr: lladdr0,
+ expectedRemoteLinkAddr: linkAddr1,
},
{
- name: "Multicast",
- remoteLinkAddr: "",
- expectedLinkAddr: mcaddr,
- expectedAddr: snaddr,
+ name: "Multicast with unspecified source",
+ nicAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedRemoteAddr: snaddr,
+ expectedRemoteLinkAddr: mcaddr,
+ },
+ {
+ name: "Unicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with unassigned address",
+ localAddr: lladdr1,
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Unicast with no local address available",
+ remoteLinkAddr: linkAddr1,
+ expectedErr: tcpip.ErrNetworkUnreachable,
+ },
+ {
+ name: "Multicast with no local address available",
+ remoteLinkAddr: "",
+ expectedErr: tcpip.ErrNetworkUnreachable,
},
}
@@ -1269,26 +1320,43 @@ func TestLinkAddressRequest(t *testing.T) {
}
linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
- if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
- t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ if err := s.CreateNIC(nicID, linkEP); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if len(test.nicAddr) != 0 {
+ if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err)
+ }
+ }
+
+ // We pass a test network interface to LinkAddressRequest with the same NIC
+ // ID and link endpoint used by the NIC we created earlier so that we can
+ // mock a link address request and observe the packets sent to the link
+ // endpoint even though the stack uses the real NIC.
+ if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
}
pkt, ok := linkEP.Read()
if !ok {
t.Fatal("expected to send a link address request")
}
- if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr {
- t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr)
+ if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
}
- if pkt.Route.RemoteAddress != test.expectedAddr {
- t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr)
+ if pkt.Route.RemoteAddress != test.expectedRemoteAddr {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr)
}
if pkt.Route.LocalAddress != lladdr1 {
t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1)
}
checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()),
checker.SrcAddr(lladdr1),
- checker.DstAddr(test.expectedAddr),
+ checker.DstAddr(test.expectedRemoteAddr),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(lladdr0),
@@ -1698,7 +1766,7 @@ func TestCallsToNeighborCache(t *testing.T) {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
nudHandler := &stubNUDHandler{}
- ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
+ ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2bd8f4ece..1e38f3a9d 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -41,12 +41,12 @@ const (
//
// Linux also uses 60 seconds for reassembly timeout:
// https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456
- reassembleTimeout = 60 * time.Second
+ ReassembleTimeout = 60 * time.Second
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
- // maxTotalSize is maximum size that can be encoded in the 16-bit
+ // maxPayloadSize is the maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
@@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 {
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
- return calculateMTU(e.nic.MTU())
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize)
+ if err != nil {
+ return 0
+ }
+ return networkMTU
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
@@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
pkt.NetworkProtocolNumber = ProtocolNumber
}
-func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool {
- return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone)
+func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool {
+ payload := pkt.TransportHeader().View().Size() + pkt.Data.Size()
+ return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU
}
// handleFragments fragments pkt and calls the handler function on each
// fragment. It returns the number of fragments handled and the number of
// fragments left to be processed. The IP header must already be present in the
-// original packet. The mtu is the maximum size of the packets. The transport
-// header protocol number is required to avoid parsing the IPv6 extension
-// headers.
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
- fragMTU := int(calculateFragmentInnerMTU(mtu, pkt))
- if fragMTU < pkt.TransportHeader().View().Size() {
+// original packet. The transport header protocol number is required to avoid
+// parsing the IPv6 extension headers.
+func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) {
+ networkHeader := header.IPv6(pkt.NetworkHeader().View())
+
+ // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
+ // supported for outbound packets, their length should not affect the fragment
+ // maximum payload length because they should only be transmitted once.
+ fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7
+ if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit {
+ // We need at least 8 bytes of space left for the fragmentable part because
+ // the fragment payload must obviously be non-zero and must be a multiple
+ // of 8 as per RFC 8200 section 4.5:
+ // Each complete fragment, except possibly the last ("rightmost") one, is
+ // an integer multiple of 8 octets long.
+ return 0, 1, tcpip.ErrMessageTooLong
+ }
+
+ if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) {
// As per RFC 8200 Section 4.5, the Transport Header is expected to be small
// enough to fit in the first fragment.
return 0, 1, tcpip.ErrMessageTooLong
}
- pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt))
+ pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt))
id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1)
- networkHeader := header.IPv6(pkt.NetworkHeader().View())
var n int
for {
@@ -416,17 +433,18 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p
}
n++
if !more {
- break
+ return n, pf.RemainingFragmentCount(), nil
}
}
-
- return n, 0, nil
}
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt, params.Protocol)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -467,8 +485,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
return nil
}
- if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+
+ if packetMustBeFragmented(pkt, networkMTU, gso) {
+ sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -498,24 +522,30 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return pkts.Len(), nil
}
+ linkMTU := e.nic.MTU()
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
e.addIPHeader(r, pb, params)
- if e.packetMustBeFragmented(pb, gso) {
- current := pb
- _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+
+ networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size()))
+ if err != nil {
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
+ return 0, err
+ }
+ if packetMustBeFragmented(pb, networkMTU, gso) {
+ // Keep track of the packet that is about to be fragmented so it can be
+ // removed once the fragmentation is done.
+ originalPkt := pb
+ if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// Modify the packet list in place with the new fragments.
- pkts.InsertAfter(current, fragPkt)
- current = current.Next()
+ pkts.InsertAfter(pb, fragPkt)
+ pb = fragPkt
return nil
- })
- if err != nil {
+ }); err != nil {
r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len()))
return 0, err
}
- // The fragmented packet can be released. The rest of the packets can be
- // processed.
- pkts.Remove(pb)
- pb = current
+ // Remove the packet that was just fragmented and process the rest.
+ pkts.Remove(originalPkt)
}
}
@@ -569,11 +599,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
-// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- // TODO(b/146666412): Support IPv6 header-included packets.
- return tcpip.ErrNotSupported
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required checks.
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
+ ip := header.IPv6(h)
+
+ // Always set the payload length.
+ pktSize := pkt.Data.Size()
+ ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == header.IPv6Any {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ proto, _, _, _, ok := parse.IPv6(pkt)
+ if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return e.writePacket(r, nil /* gso */, pkt, proto)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
@@ -718,6 +777,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
continue
}
+ fragmentFieldOffset := it.ParseOffset()
+
// Don't consume the iterator if we have the first fragment because we
// will use it to validate that the first fragment holds the upper layer
// header.
@@ -775,17 +836,59 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
return
}
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length of a fragment, as derived from the fragment packet's
+ // Payload Length field, is not a multiple of 8 octets and the M flag
+ // of that fragment is 1, then that fragment must be discarded and an
+ // ICMP Parameter Problem, Code 0, message should be sent to the source
+ // of the fragment, pointing to the Payload Length field of the
+ // fragment packet.
+ if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: header.IPv6PayloadLenOffset,
+ }, pkt)
+ return
+ }
+
// The packet is a fragment, let's try to reassemble it.
start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
- // Drop the fragment if the size of the reassembled payload would exceed
- // the maximum payload size.
+ // As per RFC 2460 Section 4.5:
+ //
+ // If the length and offset of a fragment are such that the Payload
+ // Length of the packet reassembled from that fragment would exceed
+ // 65,535 octets, then that fragment must be discarded and an ICMP
+ // Parameter Problem, Code 0, message should be sent to the source of
+ // the fragment, pointing to the Fragment Offset field of the fragment
+ // packet.
if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
+ _ = e.protocol.returnError(r, &icmpReasonParameterProblem{
+ code: header.ICMPv6ErroneousHeader,
+ pointer: fragmentFieldOffset,
+ }, pkt)
return
}
+ // Set up a callback in case we need to send a Time Exceeded Message as
+ // per RFC 2460 Section 4.5.
+ var releaseCB func(bool)
+ if start == 0 {
+ pkt := pkt.Clone()
+ r := r.Clone()
+ releaseCB = func(timedOut bool) {
+ if timedOut {
+ _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt)
+ }
+ r.Release()
+ }
+ }
+
// Note that pkt doesn't have its transport header set after reassembly,
// and won't until DeliverNetworkPacket sets it.
data, proto, ready, err := e.protocol.fragmentation.Process(
@@ -801,6 +904,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
extHdr.More(),
uint8(rawPayload.Identifier),
rawPayload.Buf,
+ releaseCB,
)
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -1398,14 +1502,31 @@ func (p *protocol) SetForwarding(v bool) {
}
}
-// calculateMTU calculates the network-layer payload MTU based on the link-layer
-// payload mtu.
-func calculateMTU(mtu uint32) uint32 {
- mtu -= header.IPv6MinimumSize
- if mtu <= maxPayloadSize {
- return mtu
+// calculateNetworkMTU calculates the network-layer payload MTU based on the
+// link-layer payload MTU and the length of every IPv6 header.
+// Note that this is different than the Payload Length field of the IPv6 header,
+// which includes the length of the extension headers.
+func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) {
+ if linkMTU < header.IPv6MinimumMTU {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ // As per RFC 7112 section 5, we should discard packets if their IPv6 header
+ // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not
+ // support PMTU discovery:
+ // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain
+ // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280
+ // bytes ensures that the header chain length does not exceed the IPv6
+ // minimum MTU.
+ if networkHeadersLen > header.IPv6MinimumMTU {
+ return 0, tcpip.ErrMalformedHeader
+ }
+
+ networkMTU := linkMTU - uint32(networkHeadersLen)
+ if networkMTU > maxPayloadSize {
+ networkMTU = maxPayloadSize
}
- return maxPayloadSize
+ return networkMTU, nil
}
// Options holds options to configure a new protocol.
@@ -1459,7 +1580,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
stack: s,
- fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()),
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()),
ids: ids,
hashIV: hashIV,
@@ -1480,23 +1601,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
return NewProtocolWithOptions(Options{})(s)
}
-// calculateFragmentInnerMTU calculates the maximum number of bytes of
-// fragmentable data a fragment can have, based on the link layer mtu and pkt's
-// network header size.
-func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 {
- // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are
- // supported for outbound packets, their length should not affect the fragment
- // MTU because they should only be transmitted once.
- mtu -= uint32(pkt.NetworkHeader().View().Size())
- mtu -= header.IPv6FragmentHeaderSize
- // Round the MTU down to align to 8 bytes.
- mtu &^= 7
- if mtu <= maxPayloadSize {
- return mtu
- }
- return maxPayloadSize
-}
-
func calculateFragmentReserve(pkt *stack.PacketBuffer) int {
return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index bee18d1a8..c593c0004 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/testutil"
@@ -49,6 +50,8 @@ const (
fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
+
+ extraHeaderReserve = 50
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -181,6 +184,9 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff)
}
+ if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve {
+ return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve)
+ }
if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber {
return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber)
}
@@ -208,8 +214,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB
reassembledPayload.Append(fragment.Data)
}
- result := reassembledPayload.ToView()
- if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" {
+ if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" {
return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff)
}
@@ -234,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(10, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
@@ -267,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -821,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, 1280, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1840,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(0, 1280, linkAddr1)
+ e := channel.New(0, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -1908,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) {
func TestInvalidIPv6Fragments(t *testing.T) {
const (
- nicID = 1
- fragmentExtHdrLen = 8
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_INVALID_IPV6_FRAGMENTS"
)
- payloadGen := func(payloadLen int) []byte {
- payload := make([]byte, payloadLen)
- for i := 0; i < len(payload); i++ {
- payload[i] = 0x30
- }
- return payload
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
}
tests := []struct {
@@ -1925,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) {
fragments []fragmentData
wantMalformedIPPackets uint64
wantMalformedFragments uint64
+ expectICMP bool
+ expectICMPType header.ICMPv6Type
+ expectICMPCode header.ICMPv6Code
+ expectICMPTypeSpecific uint32
}{
{
+ name: "fragment size is not a multiple of 8 and the M flag is true",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 9,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0 >> 3,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:9],
+ },
+ },
+ wantMalformedIPPackets: 1,
+ wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6PayloadLenOffset,
+ },
+ {
name: "fragments reassembled into a payload exceeding the max IPv6 payload size",
fragments: []fragmentData{
{
- srcAddr: addr1,
- dstAddr: addr2,
- nextHdr: fragmentExtHdrID,
- data: buffer.NewVectorisedView(
- fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16,
- []buffer.View{
- // Fragment extension header.
- // Fragment offset = 8190, More = false, ID = 1
- buffer.View([]byte{uint8(header.UDPProtocolNumber), 0,
- ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8,
- ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8,
- 0, 0, 0, 1}),
- // Payload length = 16
- payloadGen(16),
- },
- ),
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
},
},
wantMalformedIPPackets: 1,
wantMalformedFragments: 1,
+ expectICMP: true,
+ expectICMPType: header.ICMPv6ParamProblem,
+ expectICMPCode: header.ICMPv6ErroneousHeader,
+ expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */
},
}
@@ -1960,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) {
NewProtocol,
},
})
- e := channel.New(0, 1500, linkAddr1)
+ e := channel.New(1, 1500, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+ var expectICMPPayload buffer.View
for _, f := range test.fragments {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
- // Serialize IPv6 fixed header.
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(f.data.Size()),
- NextHeader: f.nextHdr,
- HopLimit: 255,
- SrcAddr: f.srcAddr,
- DstAddr: f.dstAddr,
- })
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
vv := hdr.View().ToVectorisedView()
- vv.Append(f.data)
+ vv.AppendView(f.payload)
- e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: vv,
- }))
+ })
+
+ if test.expectICMP {
+ expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
}
if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want {
@@ -1995,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) {
if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want {
t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want)
}
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(test.expectICMPType),
+ checker.ICMPv6Code(test.expectICMPCode),
+ checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific),
+ checker.ICMPv6Payload([]byte(expectICMPPayload)),
+ ),
+ )
+ })
+ }
+}
+
+func TestFragmentReassemblyTimeout(t *testing.T) {
+ const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ nicID = 1
+ hoplimit = 255
+ ident = 1
+ data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT"
+ )
+
+ type fragmentData struct {
+ ipv6Fields header.IPv6Fields
+ ipv6FragmentFields header.IPv6FragmentFields
+ payload []byte
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectICMP bool
+ }{
+ {
+ name: "first fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two first fragments",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "second fragment only",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: false,
+ },
+ {
+ name: "two fragments with a gap",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ },
+ expectICMP: true,
+ },
+ {
+ name: "two fragments with a gap in reverse order",
+ fragments: []fragmentData{
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16),
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 8,
+ M: false,
+ Identification: ident,
+ },
+ payload: []byte(data)[16:],
+ },
+ {
+ ipv6Fields: header.IPv6Fields{
+ PayloadLength: header.IPv6FragmentHeaderSize + 16,
+ NextHeader: header.IPv6FragmentHeader,
+ HopLimit: hoplimit,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ },
+ ipv6FragmentFields: header.IPv6FragmentFields{
+ NextHeader: uint8(header.UDPProtocolNumber),
+ FragmentOffset: 0,
+ M: true,
+ Identification: ident,
+ },
+ payload: []byte(data)[:16],
+ },
+ },
+ expectICMP: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ clock := faketime.NewManualClock()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{
+ NewProtocol,
+ },
+ Clock: clock,
+ })
+
+ e := channel.New(1, 1500, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ }})
+
+ var firstFragmentSent buffer.View
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize))
+ ip.Encode(&f.ipv6Fields)
+
+ fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:])
+ fragHDR.Encode(&f.ipv6FragmentFields)
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(f.payload)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+
+ if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 {
+ firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader())
+ }
+
+ e.InjectInbound(ProtocolNumber, pkt)
+ }
+
+ clock.Advance(ReassembleTimeout)
+
+ reply, ok := e.Read()
+ if !test.expectICMP {
+ if ok {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+ return
+ }
+ if !ok {
+ t.Fatal("expected ICMP error message missing")
+ }
+ if firstFragmentSent == nil {
+ t.Fatalf("unexpected ICMP error message received: %#v", reply)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()),
+ checker.SrcAddr(addr2),
+ checker.DstAddr(addr1),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6TimeExceeded),
+ checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout),
+ checker.ICMPv6Payload([]byte(firstFragmentSent)),
+ ),
+ )
})
}
}
@@ -2217,24 +2546,19 @@ type fragmentInfo struct {
payloadSize uint16
}
-type fragmentationTestCase struct {
+var fragmentationTests = []struct {
description string
mtu uint32
gso *stack.GSO
transHdrLen int
- extraHdrLen int
payloadSize int
wantFragments []fragmentInfo
- expectedFrags int
-}
-
-var fragmentationTests = []fragmentationTestCase{
+}{
{
- description: "No Fragmentation",
- mtu: 1280,
- gso: &stack.GSO{},
+ description: "No fragmentation",
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1000, more: false},
@@ -2242,10 +2566,20 @@ var fragmentationTests = []fragmentationTestCase{
},
{
description: "Fragmented",
- mtu: 1280,
- gso: &stack.GSO{},
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
+ transHdrLen: 0,
+ payloadSize: 2000,
+ wantFragments: []fragmentInfo{
+ {offset: 0, payloadSize: 1240, more: true},
+ {offset: 154, payloadSize: 776, more: false},
+ },
+ },
+ {
+ description: "Fragmented with mtu not a multiple of 8",
+ mtu: header.IPv6MinimumMTU + 1,
+ gso: nil,
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 2000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
@@ -2255,20 +2589,18 @@ var fragmentationTests = []fragmentationTestCase{
{
description: "No fragmentation with big header",
mtu: 2000,
- gso: &stack.GSO{},
+ gso: nil,
transHdrLen: 100,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1000,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1100, more: false},
},
},
{
- description: "Fragmented with gso nil",
- mtu: 1280,
- gso: nil,
+ description: "Fragmented with gso none",
+ mtu: header.IPv6MinimumMTU,
+ gso: &stack.GSO{Type: stack.GSONone},
transHdrLen: 0,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1400,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
@@ -2277,31 +2609,18 @@ var fragmentationTests = []fragmentationTestCase{
},
{
description: "Fragmented with big header",
- mtu: 1280,
- gso: &stack.GSO{},
+ mtu: header.IPv6MinimumMTU,
+ gso: nil,
transHdrLen: 100,
- extraHdrLen: header.IPv6MinimumSize,
payloadSize: 1200,
wantFragments: []fragmentInfo{
{offset: 0, payloadSize: 1240, more: true},
{offset: 154, payloadSize: 76, more: false},
},
},
- {
- description: "Fragmented with big header and prependable bytes",
- mtu: 1280,
- gso: &stack.GSO{},
- transHdrLen: 20,
- extraHdrLen: header.IPv6MinimumSize + 66,
- payloadSize: 1500,
- wantFragments: []fragmentInfo{
- {offset: 0, payloadSize: 1240, more: true},
- {offset: 154, payloadSize: 296, more: false},
- },
- },
}
-func TestFragmentation(t *testing.T) {
+func TestFragmentationWritePacket(t *testing.T) {
const (
ttl = 42
tos = stack.DefaultTOS
@@ -2310,7 +2629,7 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragmentationTests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt.Clone()
ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32)
r := buildRoute(t, ep)
@@ -2331,10 +2650,8 @@ func TestFragmentation(t *testing.T) {
if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 {
t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got)
}
- if len(ep.WrittenPackets) > 0 {
- if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
- t.Error(err)
- }
+ if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil {
+ t.Error(err)
}
})
}
@@ -2368,7 +2685,7 @@ func TestFragmentationWritePackets(t *testing.T) {
insertAfter: 1,
},
}
- tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
+ tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber)
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
@@ -2378,7 +2695,7 @@ func TestFragmentationWritePackets(t *testing.T) {
for i := 0; i < test.insertBefore; i++ {
pkts.PushBack(tinyPacket.Clone())
}
- pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
source := pkt
pkts.PushBack(pkt.Clone())
for i := 0; i < test.insertAfter; i++ {
@@ -2467,8 +2784,8 @@ func TestFragmentationErrors(t *testing.T) {
wantError: tcpip.ErrAborted,
},
{
- description: "Error on packet with MTU smaller than transport header",
- mtu: 1280,
+ description: "Error when MTU is smaller than transport header",
+ mtu: header.IPv6MinimumMTU,
transHdrLen: 1500,
payloadSize: 500,
allowPackets: 0,
@@ -2476,11 +2793,21 @@ func TestFragmentationErrors(t *testing.T) {
mockError: nil,
wantError: tcpip.ErrMessageTooLong,
},
+ {
+ description: "Error when MTU is smaller than IPv6 minimum MTU",
+ mtu: header.IPv6MinimumMTU - 1,
+ transHdrLen: 0,
+ payloadSize: 500,
+ allowPackets: 0,
+ outgoingErrors: 1,
+ mockError: nil,
+ wantError: tcpip.ErrInvalidEndpointState,
+ },
}
for _, ft := range tests {
t.Run(ft.description, func(t *testing.T) {
- pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
+ pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber)
ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets)
r := buildRoute(t, ep)
err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 4d3acab96..261705575 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -361,6 +361,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) *
return tcpip.ErrInvalidEndpointState
}
+ a.mu.Lock()
+ defer a.mu.Unlock()
return a.removePermanentEndpointLocked(addrState)
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go
index cf042309e..380688038 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {}
func (*fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 6f73a0ce4..c9b13cd0e 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
return addr, nil, nil
@@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
entry.done = make(chan struct{})
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
}
return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
@@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
}
}
-func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 33806340e..d2e37f38d 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -49,8 +49,8 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
- time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
+ time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) })
if f := r.onLinkAddressRequest; f != nil {
f()
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 4df288798..eebf43a1f 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -68,7 +68,7 @@ var _ NUDHandler = (*neighborCache)(nil)
// reset to state incomplete, and returned. If no matching entry exists and the
// cache is not full, a new entry with state incomplete is allocated and
// returned.
-func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
n.mu.Lock()
defer n.mu.Unlock()
@@ -84,7 +84,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// The entry that needs to be created must be dynamic since all static
// entries are directly added to the cache via addStaticEntry.
- entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes)
if n.dynamic.count == neighborCacheSize {
e := n.dynamic.lru.Back()
e.mu.Lock()
@@ -111,6 +111,10 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li
// provided, it will be notified when address resolution is complete (success
// or not).
//
+// If specified, the local address must be an address local to the interface the
+// neighbor cache belongs to. The local address is the source address of a
+// packet prompting NUD/link address resolution.
+//
// If address resolution is required, ErrNoLinkAddress and a notification
// channel is returned for the top level caller to block. Channel is closed
// once address resolution is complete (success or not).
@@ -118,7 +122,6 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
e := NeighborEntry{
Addr: remoteAddr,
- LocalAddr: localAddr,
LinkAddr: linkAddr,
State: Static,
UpdatedAt: time.Now(),
@@ -126,13 +129,13 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
return e, nil, nil
}
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
case Stale:
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
fallthrough
case Reachable, Static, Delay, Probe:
// As per RFC 4861 section 7.3.3:
@@ -152,7 +155,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.done = make(chan struct{})
}
- entry.handlePacketQueuedLocked()
+ entry.handlePacketQueuedLocked(localAddr)
return entry.neigh, entry.done, tcpip.ErrWouldBlock
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
@@ -207,7 +210,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
} else {
// Static entry found with the same address but different link address.
entry.neigh.LinkAddr = linkAddr
- entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.dispatchChangeEventLocked()
entry.mu.Unlock()
return
}
@@ -220,8 +223,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd
entry.mu.Unlock()
}
- entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
- n.cache[addr] = entry
+ n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
}
// removeEntryLocked removes the specified entry from the neighbor cache.
@@ -292,8 +294,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) {
// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
// by the caller.
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
- entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
entry.handleProbeLocked(remoteLinkAddr)
entry.mu.Unlock()
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index fcd54ed83..d81f00848 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -128,9 +128,8 @@ func newTestEntryStore() *testEntryStore {
linkAddr := toLinkAddress(i)
store.entriesMap[addr] = NeighborEntry{
- Addr: addr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: linkAddr,
+ Addr: addr,
+ LinkAddr: linkAddr,
}
}
return store
@@ -195,10 +194,10 @@ type testNeighborResolver struct {
var _ LinkAddressResolver = (*testNeighborResolver)(nil)
-func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
// Delay handling the request to emulate network latency.
r.clock.AfterFunc(r.delay, func() {
- r.fakeRequest(addr)
+ r.fakeRequest(targetAddr)
})
// Execute post address resolution action, if available.
@@ -294,9 +293,8 @@ func TestNeighborCacheEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -305,15 +303,19 @@ func TestNeighborCacheEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -324,8 +326,8 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
// No more events should have been dispatched.
@@ -354,9 +356,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -365,15 +367,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -391,9 +397,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -404,8 +412,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
}
@@ -452,8 +460,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if !ok {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
- if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
- return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
@@ -470,23 +478,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
})
}
wantEvents = append(wantEvents, testEntryEventInfo{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
}, testEntryEventInfo{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
})
c.nudDisp.mu.Lock()
@@ -508,10 +522,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
return fmt.Errorf("c.store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -564,24 +577,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -600,9 +616,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -640,9 +658,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -682,9 +702,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -703,9 +725,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T)
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -740,9 +764,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -760,9 +786,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -800,24 +828,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -836,16 +867,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: staticLinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -861,10 +896,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: staticLinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
},
},
}
@@ -896,12 +930,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
clock.Advance(typicalLatency)
@@ -913,7 +947,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
id, ok := s.Fetch(false /* block */)
if !ok {
- t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr)
}
if id != wakerID {
t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
@@ -923,15 +957,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -964,12 +1002,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh == nil {
- t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr)
}
// Remove the waker before the neighbor cache has the opportunity to send a
@@ -991,15 +1029,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1028,10 +1070,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
@@ -1041,9 +1082,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1058,10 +1101,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
startAtEntryIndex: 1,
wantStaticEntries: []NeighborEntry{
{
- Addr: entry.Addr,
- LocalAddr: "", // static entries don't need a local address
- LinkAddr: entry.LinkAddr,
- State: Static,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
},
},
}
@@ -1089,9 +1131,8 @@ func TestNeighborCacheClear(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
@@ -1099,15 +1140,19 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1126,9 +1171,11 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1149,16 +1196,20 @@ func TestNeighborCacheClear(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Static,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
},
}
nudDisp.mu.Lock()
@@ -1185,24 +1236,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if !ok {
t.Fatalf("c.store.entry(0) not found")
}
- _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1220,9 +1274,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
c.nudDisp.mu.Lock()
@@ -1274,29 +1330,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1312,9 +1372,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
for i := neighborCacheSize; i < store.size(); i++ {
// Periodically refresh the frequently used entry
if i%(neighborCacheSize/2) == 0 {
- _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err)
}
}
@@ -1322,15 +1381,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// An entry should have been removed, as per the LRU eviction strategy
@@ -1342,22 +1401,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
{
EventType: entryTestRemoved,
NICID: 1,
- Addr: removedEntry.Addr,
- LinkAddr: removedEntry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
},
{
EventType: entryTestAdded,
NICID: 1,
- Addr: entry.Addr,
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: 1,
- Addr: entry.Addr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1374,10 +1439,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
// have to be sorted before comparison.
wantUnsortedEntries := []NeighborEntry{
{
- Addr: frequentlyUsedEntry.Addr,
- LocalAddr: frequentlyUsedEntry.LocalAddr,
- LinkAddr: frequentlyUsedEntry.LinkAddr,
- State: Reachable,
+ Addr: frequentlyUsedEntry.Addr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
},
}
@@ -1387,10 +1451,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1430,9 +1493,8 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Add(1)
go func(entry NeighborEntry) {
defer wg.Done()
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil && err != tcpip.ErrWouldBlock {
- t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock)
}
}(entry)
}
@@ -1456,10 +1518,9 @@ func TestNeighborCacheConcurrent(t *testing.T) {
t.Errorf("store.entry(%d) not found", i)
}
wantEntry := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
}
@@ -1488,37 +1549,36 @@ func TestNeighborCacheReplace(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr)
}
// Verify the entry exists
{
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh)
}
if t.Failed() {
t.FailNow()
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
@@ -1542,37 +1602,35 @@ func TestNeighborCacheReplace(t *testing.T) {
//
// Verify the entry's new link address and the new state.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Delay,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
// Verify that the neighbor is now reachable.
{
- e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ e, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: updatedLinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
}
}
@@ -1601,35 +1659,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
clock.Advance(typicalLatency)
- got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err)
}
want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
}
// Verify that address resolution for an unknown address returns ErrNoLinkAddress
before := atomic.LoadUint32(&requestCount)
entry.Addr += "2"
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
maxAttempts := neigh.config().MaxUnicastProbes
@@ -1659,13 +1716,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
if !ok {
t.Fatalf("store.entry(0) not found")
}
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
clock.Advance(waitFor)
- if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress)
}
}
@@ -1683,18 +1740,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) {
delay: typicalLatency,
}
- got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil)
if err != nil {
- t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err)
}
want := NeighborEntry{
- Addr: testEntryBroadcastAddr,
- LocalAddr: testEntryLocalAddr,
- LinkAddr: testEntryBroadcastLinkAddr,
- State: Static,
+ Addr: testEntryBroadcastAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
}
if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff)
}
}
@@ -1719,9 +1775,9 @@ func BenchmarkCacheClear(b *testing.B) {
if !ok {
b.Fatalf("store.entry(%d) not found", i)
}
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil)
if err != tcpip.ErrWouldBlock {
- b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock)
}
if doneCh != nil {
<-doneCh
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index be61a21af..bd80f95bd 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -27,7 +27,6 @@ import (
// NeighborEntry describes a neighboring device in the local network.
type NeighborEntry struct {
Addr tcpip.Address
- LocalAddr tcpip.Address
LinkAddr tcpip.LinkAddress
State NeighborState
UpdatedAt time.Time
@@ -106,35 +105,35 @@ type neighborEntry struct {
// state, Unknown. Transition out of Unknown by calling either
// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
// neighborEntry.
-func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
return &neighborEntry{
nic: nic,
linkRes: linkRes,
nudState: nudState,
neigh: NeighborEntry{
- Addr: remoteAddr,
- LocalAddr: localAddr,
- State: Unknown,
+ Addr: remoteAddr,
+ State: Unknown,
},
}
}
-// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
-// state. The entry can only transition out of Static by directly calling
-// `setStateLocked`.
+// newStaticNeighborEntry creates a neighbor cache entry starting at the
+// Static state. The entry can only transition out of Static by directly
+// calling `setStateLocked`.
func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ entry := NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ }
if nic.stack.nudDisp != nil {
- nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, entry)
}
return &neighborEntry{
nic: nic,
nudState: state,
- neigh: NeighborEntry{
- Addr: addr,
- LinkAddr: linkAddr,
- State: Static,
- UpdatedAt: time.Now(),
- },
+ neigh: entry,
}
}
@@ -165,17 +164,17 @@ func (e *neighborEntry) notifyWakersLocked() {
// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
// been added.
-func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchAddEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh)
}
}
// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
// has changed state or link-layer address.
-func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+func (e *neighborEntry) dispatchChangeEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh)
}
}
@@ -183,7 +182,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
// has been removed.
func (e *neighborEntry) dispatchRemoveEventLocked() {
if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
- nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh)
}
}
@@ -206,63 +205,19 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
switch next {
case Incomplete:
- var retryCounter uint32
- var sendMulticastProbe func()
-
- sendMulticastProbe = func() {
- if retryCounter == config.MaxMulticastProbes {
- // "If no Neighbor Advertisement is received after
- // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
- // The sender MUST return ICMP destination unreachable indications with
- // code 3 (Address Unreachable) for each packet queued awaiting address
- // resolution." - RFC 4861 section 7.2.2
- //
- // There is no need to send an ICMP destination unreachable indication
- // since the failure to resolve the address is expected to only occur
- // on this node. Thus, redirecting traffic is currently not supported.
- //
- // "If the error occurs on a node other than the node originating the
- // packet, an ICMP error message is generated. If the error occurs on
- // the originating node, an implementation is not required to actually
- // create and send an ICMP error packet to the source, as long as the
- // upper-layer sender is notified through an appropriate mechanism
- // (e.g. return value from a procedure call). Note, however, that an
- // implementation may find it convenient in some cases to return errors
- // to the sender by taking the offending packet, generating an ICMP
- // error message, and then delivering it (locally) through the generic
- // error-handling routines.' - RFC 4861 section 2.1
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
- // There is no need to log the error here; the NUD implementation may
- // assume a working link. A valid link should be the responsibility of
- // the NIC/stack.LinkEndpoint.
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
- retryCounter++
- e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
- e.job.Schedule(config.RetransmitTimer)
- }
-
- sendMulticastProbe()
+ panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev))
case Reachable:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(e.nudState.ReachableTime())
case Delay:
e.job = e.nic.stack.newJob(&e.mu, func() {
- e.dispatchChangeEventLocked(Probe)
e.setStateLocked(Probe)
+ e.dispatchChangeEventLocked()
})
e.job.Schedule(config.DelayFirstProbeTime)
@@ -277,19 +232,13 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
}
retryCounter++
- if retryCounter == config.MaxUnicastProbes {
- e.dispatchRemoveEventLocked()
- e.setStateLocked(Failed)
- return
- }
-
e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
e.job.Schedule(config.RetransmitTimer)
}
@@ -315,15 +264,72 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
// being queued for outgoing transmission.
//
// Follows the logic defined in RFC 4861 section 7.3.3.
-func (e *neighborEntry) handlePacketQueuedLocked() {
+func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) {
switch e.neigh.State {
case Unknown:
- e.dispatchAddEventLocked(Incomplete)
- e.setStateLocked(Incomplete)
+ e.neigh.State = Incomplete
+ e.neigh.UpdatedAt = time.Now()
+
+ e.dispatchAddEventLocked()
+
+ config := e.nudState.Config()
+
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ // As per RFC 4861 section 7.2.2:
+ //
+ // If the source address of the packet prompting the solicitation is the
+ // same as one of the addresses assigned to the outgoing interface, that
+ // address SHOULD be placed in the IP Source Address of the outgoing
+ // solicitation.
+ //
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendMulticastProbe()
case Stale:
- e.dispatchChangeEventLocked(Delay)
e.setStateLocked(Delay)
+ e.dispatchChangeEventLocked()
case Incomplete, Reachable, Delay, Probe, Static, Failed:
// Do nothing
@@ -345,21 +351,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
switch e.neigh.State {
case Unknown, Incomplete, Failed:
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchAddEventLocked(Stale)
e.setStateLocked(Stale)
e.notifyWakersLocked()
+ e.dispatchAddEventLocked()
case Reachable, Delay, Probe:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Stale:
if e.neigh.LinkAddr != remoteLinkAddr {
e.neigh.LinkAddr = remoteLinkAddr
- e.dispatchChangeEventLocked(Stale)
+ e.dispatchChangeEventLocked()
}
case Static:
@@ -393,12 +399,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
e.neigh.LinkAddr = linkAddr
if flags.Solicited {
- e.dispatchChangeEventLocked(Reachable)
e.setStateLocked(Reachable)
} else {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
}
+ e.dispatchChangeEventLocked()
e.isRouter = flags.IsRouter
e.notifyWakersLocked()
@@ -411,8 +416,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if isLinkAddrDifferent {
if !flags.Override {
if e.neigh.State == Reachable {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
}
break
}
@@ -421,23 +426,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
if !flags.Solicited {
if e.neigh.State != Stale {
- e.dispatchChangeEventLocked(Stale)
e.setStateLocked(Stale)
+ e.dispatchChangeEventLocked()
} else {
// Notify the LinkAddr change, even though NUD state hasn't changed.
- e.dispatchChangeEventLocked(e.neigh.State)
+ e.dispatchChangeEventLocked()
}
break
}
}
if flags.Solicited && (flags.Override || !isLinkAddrDifferent) {
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- }
+ wasReachable := e.neigh.State == Reachable
// Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
e.notifyWakersLocked()
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
}
if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) {
@@ -475,11 +481,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla
func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
switch e.neigh.State {
case Reachable, Stale, Delay, Probe:
- if e.neigh.State != Reachable {
- e.dispatchChangeEventLocked(Reachable)
- // Set state to Reachable again to refresh timers.
- }
+ wasReachable := e.neigh.State == Reachable
+ // Set state to Reachable again to refresh timers.
e.setStateLocked(Reachable)
+ if !wasReachable {
+ e.dispatchChangeEventLocked()
+ }
case Unknown, Incomplete, Failed, Static:
// Do nothing
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index 3ee2a3b31..e8e0e571b 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -52,19 +52,16 @@ const (
// predict the time that an event will be dispatched.
func eventDiffOpts() []cmp.Option {
return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
}
}
// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
// sort slices of events for cases where ordering must be ignored.
func eventDiffOptsWithSort() []cmp.Option {
- return []cmp.Option{
- cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
- cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
- return strings.Compare(string(a.Addr), string(b.Addr)) < 0
- }),
- }
+ return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0
+ }))
}
// The following unit tests exercise every state transition and verify its
@@ -125,14 +122,11 @@ func (t testEntryEventType) String() string {
type testEntryEventInfo struct {
EventType testEntryEventType
NICID tcpip.NICID
- Addr tcpip.Address
- LinkAddr tcpip.LinkAddress
- State NeighborState
- UpdatedAt time.Time
+ Entry NeighborEntry
}
func (e testEntryEventInfo) String() string {
- return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+ return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry)
}
// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
@@ -150,36 +144,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
d.events = append(d.events, e)
}
-func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestAdded,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestChanged,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
-func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) {
d.queueEvent(testEntryEventInfo{
EventType: entryTestRemoved,
NICID: nicID,
- Addr: addr,
- LinkAddr: linkAddr,
- State: state,
- UpdatedAt: updatedAt,
+ Entry: entry,
})
}
@@ -202,9 +187,9 @@ func (p entryTestProbeInfo) String() string {
// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
// to the local network if linkAddr is the zero value.
-func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
p := entryTestProbeInfo{
- RemoteAddress: addr,
+ RemoteAddress: targetAddr,
RemoteLinkAddress: linkAddr,
LocalAddress: localAddr,
}
@@ -245,7 +230,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
nudState := NewNUDState(c, rng)
linkRes := entryTestLinkResolver{}
- entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes)
// Stub out the neighbor cache to verify deletion from the cache.
nic.neigh = &neighborCache{
@@ -326,7 +311,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -350,9 +335,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
{
@@ -388,9 +375,11 @@ func TestEntryUnknownToStale(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -406,7 +395,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -468,16 +457,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -498,7 +491,7 @@ func TestEntryIncompleteToReachable(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -530,16 +523,20 @@ func TestEntryIncompleteToReachable(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -563,7 +560,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
defer s.Done()
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got := e.wakers; got != nil {
t.Errorf("got e.wakers = %v, want = nil", got)
}
@@ -605,16 +602,20 @@ func TestEntryAddsAndClearsWakers(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -629,7 +630,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -663,16 +664,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -687,7 +692,7 @@ func TestEntryIncompleteToStale(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -719,16 +724,20 @@ func TestEntryIncompleteToStale(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -744,7 +753,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Incomplete; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -783,16 +792,20 @@ func TestEntryIncompleteToFailed(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
}
nudDisp.mu.Lock()
@@ -822,7 +835,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -866,16 +879,20 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -896,7 +913,7 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -932,16 +949,20 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -961,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -992,23 +1013,29 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1029,7 +1056,7 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -1062,23 +1089,29 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1099,7 +1132,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -1136,23 +1169,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T)
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1173,7 +1212,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: true,
Override: false,
@@ -1210,23 +1249,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1247,7 +1292,7 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1283,16 +1328,20 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1307,7 +1356,7 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1347,23 +1396,29 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1378,7 +1433,7 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1418,23 +1473,29 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
}
nudDisp.mu.Lock()
@@ -1449,7 +1510,7 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1489,23 +1550,29 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1520,7 +1587,7 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1556,23 +1623,29 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1587,7 +1660,7 @@ func TestEntryStaleToDelay(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
@@ -1596,7 +1669,7 @@ func TestEntryStaleToDelay(t *testing.T) {
if got, want := e.neigh.State, Stale; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -1620,23 +1693,29 @@ func TestEntryStaleToDelay(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1656,13 +1735,13 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -1692,37 +1771,47 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1743,13 +1832,13 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -1786,37 +1875,47 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1837,13 +1936,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if e.neigh.State != Delay {
t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay)
}
@@ -1880,37 +1979,47 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -1925,13 +2034,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -1966,23 +2075,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
}
nudDisp.mu.Lock()
@@ -1997,13 +2112,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -2031,30 +2146,38 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2069,13 +2192,13 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, _ := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -2107,30 +2230,38 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2145,13 +2276,13 @@ func TestEntryDelayToProbe(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
if got, want := e.neigh.State, Delay; got != want {
t.Errorf("got e.neigh.State = %q, want = %q", got, want)
}
@@ -2170,7 +2301,6 @@ func TestEntryDelayToProbe(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2184,30 +2314,38 @@ func TestEntryDelayToProbe(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2228,13 +2366,13 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2250,7 +2388,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2274,37 +2411,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2325,13 +2472,13 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2347,7 +2494,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2375,37 +2521,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2426,13 +2582,13 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2448,7 +2604,6 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2479,30 +2634,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -2529,7 +2692,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.mu.Lock()
e.handleProbeLocked(entryTestLinkAddr1)
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2539,7 +2702,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2572,37 +2734,47 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2622,13 +2794,13 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2644,7 +2816,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2677,44 +2848,56 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr2,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2734,13 +2917,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2756,7 +2939,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2786,44 +2968,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2843,13 +3037,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
clock.Advance(c.DelayFirstProbeTime)
@@ -2865,7 +3059,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -2895,44 +3088,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Reachable,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
}
nudDisp.mu.Lock()
@@ -2946,87 +3151,116 @@ func TestEntryProbeToFailed(t *testing.T) {
c := DefaultNUDConfigurations()
c.MaxMulticastProbes = 3
c.MaxUnicastProbes = 3
+ c.DelayFirstProbeTime = c.RetransmitTimer
e, nudDisp, linkRes, clock := entryTestSetup(c)
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
+ e.mu.Unlock()
+
+ {
+ wantProbes := []entryTestProbeInfo{
+ // Caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ e.mu.Lock()
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
- waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.Advance(waitFor)
+ // Observe each probe sent while in the Probe state.
+ for i := uint32(0); i < c.MaxUnicastProbes; i++ {
+ clock.Advance(c.RetransmitTimer)
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff)
+ }
- wantProbes := []entryTestProbeInfo{
- // The first probe is caused by the Unknown-to-Incomplete transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: tcpip.LinkAddress(""),
- LocalAddress: entryTestAddr2,
- },
- // The next three probe are caused by the Delay-to-Probe transition.
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
- {
- RemoteAddress: entryTestAddr1,
- RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
- },
+ e.mu.Lock()
+ if e.neigh.State != Probe {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe)
+ }
+ e.mu.Unlock()
}
- linkRes.mu.Lock()
- diff := cmp.Diff(linkRes.probes, wantProbes)
- linkRes.mu.Unlock()
- if diff != "" {
- t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+
+ // Wait for the last probe to expire, causing a transition to Failed.
+ clock.Advance(c.RetransmitTimer)
+ e.mu.Lock()
+ if e.neigh.State != Failed {
+ t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed)
}
+ e.mu.Unlock()
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
@@ -3034,12 +3268,6 @@ func TestEntryProbeToFailed(t *testing.T) {
t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
}
nudDisp.mu.Unlock()
-
- e.mu.Lock()
- if got, want := e.neigh.State, Failed; got != want {
- t.Errorf("got e.neigh.State = %q, want = %q", got, want)
- }
- e.mu.Unlock()
}
func TestEntryFailedGetsDeleted(t *testing.T) {
@@ -3054,13 +3282,13 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
}
e.mu.Lock()
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
Solicited: false,
Override: false,
IsRouter: false,
})
- e.handlePacketQueuedLocked()
+ e.handlePacketQueuedLocked(entryTestAddr2)
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
@@ -3077,17 +3305,14 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
{
RemoteAddress: entryTestAddr1,
RemoteLinkAddress: entryTestLinkAddr1,
- LocalAddress: entryTestAddr2,
},
}
linkRes.mu.Lock()
@@ -3101,37 +3326,47 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
{
EventType: entryTestAdded,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: tcpip.LinkAddress(""),
- State: Incomplete,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Stale,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Delay,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
},
{
EventType: entryTestChanged,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
{
EventType: entryTestRemoved,
NICID: entryTestNICID,
- Addr: entryTestAddr1,
- LinkAddr: entryTestLinkAddr1,
- State: Probe,
+ Entry: NeighborEntry{
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
},
}
nudDisp.mu.Lock()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8828cc5fe..17f2e6b46 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -274,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
return n.writePacket(r, gso, protocol, pkt)
}
+// WritePacketToRemote implements NetworkInterface.
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ r := Route{
+ NetProto: protocol,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ return n.writePacket(&r, gso, protocol, pkt)
+}
+
func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
@@ -679,14 +687,17 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease
+ // the TTL field for ipv4/ipv6.
// pkt may have set its header and may not have enough headroom for
// link-layer header for the other link to prepend. Here we create a new
// packet to forward.
fwdPkt := NewPacketBuffer(PacketBufferOptions{
ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ // We need to do a deep copy of the IP packet because WritePacket (and
+ // friends) take ownership of the packet buffer, but we do not own it.
+ Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(),
})
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 97a96af62..4af04846f 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -169,7 +169,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
index e1ec15487..ab629b3a4 100644
--- a/pkg/tcpip/stack/nud.go
+++ b/pkg/tcpip/stack/nud.go
@@ -129,7 +129,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborAdded(tcpip.NICID, NeighborEntry)
// OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
// neighbor table changes state and/or link address.
@@ -138,7 +138,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborChanged(tcpip.NICID, NeighborEntry)
// OnNeighborRemoved will be called when an entry is removed from a NIC's
// (with ID nicID) neighbor table.
@@ -147,7 +147,7 @@ type NUDDispatcher interface {
// the stack's operation.
//
// May be called concurrently.
- OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+ OnNeighborRemoved(tcpip.NICID, NeighborEntry)
}
// ReachabilityConfirmationFlags describes the flags used within a reachability
@@ -177,7 +177,7 @@ type NUDHandler interface {
// Neighbor Solicitation for ARP or NDP, respectively). Validation of the
// probe needs to be performed before calling this function since the
// Neighbor Cache doesn't have access to view the NIC's assigned addresses.
- HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+ HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
// HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
// reply or Neighbor Advertisement for ARP or NDP, respectively).
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 105583c49..7f54a6de8 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -311,11 +311,25 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
}
// PayloadSince returns packet payload starting from and including a particular
-// header. This method isn't optimized and should be used in test only.
+// header.
+//
+// The returned View is owned by the caller - its backing buffer is separate
+// from the packet header's underlying packet buffer.
func PayloadSince(h PacketHeader) buffer.View {
- var v buffer.View
+ size := h.pk.Data.Size()
+ for _, hinfo := range h.pk.headers[h.typ:] {
+ size += len(hinfo.buf)
+ }
+
+ v := make(buffer.View, 0, size)
+
for _, hinfo := range h.pk.headers[h.typ:] {
v = append(v, hinfo.buf...)
}
- return append(v, h.pk.Data.ToView()...)
+
+ for _, view := range h.pk.Data.Views() {
+ v = append(v, view...)
+ }
+
+ return v
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index defb9129b..203f3b51f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -490,6 +490,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // WritePacketToRemote writes the packet to the given remote link address.
+ WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -764,13 +767,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
- // the request on the local network if remoteLinkAddr is the zero value. The
- // request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the link address of the target
+ // address. The request is broadcasted on the local network if a remote link
+ // address is not provided.
//
- // A valid response will cause the discovery protocol's network
- // endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
+ // The request is sent from the passed network interface. If the interface
+ // local address is unspecified, any interface local address may be used.
+ LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3a07577c8..e8f1c110e 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -830,6 +830,20 @@ func (s *Stack) AddRoute(route tcpip.Route) {
s.routeTable = append(s.routeTable, route)
}
+// RemoveRoutes removes matching routes from the route table.
+func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var filteredRoutes []tcpip.Route
+ for _, route := range s.routeTable {
+ if !match(route) {
+ filteredRoutes = append(filteredRoutes, route)
+ }
+ }
+ s.routeTable = filteredRoutes
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -1323,7 +1337,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker)
}
// Neighbors returns all IP to MAC address associations.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e75f58c64..4eed4ced4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -3672,3 +3672,87 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
}
}
+
+// TestAddRoute tests Stack.AddRoute
+func TestAddRoute(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ subnet1, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expected := []tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ }
+
+ // Initialize the route table with one route.
+ s.SetRouteTable([]tcpip.Route{expected[0]})
+
+ // Add another route.
+ s.AddRoute(expected[1])
+
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
+
+// TestRemoveRoutes tests Stack.RemoveRoutes
+func TestRemoveRoutes(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{})
+
+ addressToRemove := tcpip.Address("\x01")
+ subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ subnet3, err := tcpip.NewSubnet("\x02", "\x02")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Initialize the route table with three routes.
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: 1},
+ {Destination: subnet2, Gateway: "\x00", NIC: 1},
+ {Destination: subnet3, Gateway: "\x00", NIC: 1},
+ })
+
+ // Remove routes with the specific address.
+ s.RemoveRoutes(func(r tcpip.Route) bool {
+ return r.Destination.ID() == addressToRemove
+ })
+
+ expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}}
+ rt := s.GetRouteTable()
+ if got, want := len(rt), len(expected); got != want {
+ t.Fatalf("Unexpected route table length got = %d, want = %d", got, want)
+ }
+ for i, route := range rt {
+ if got, want := route, expected[i]; got != want {
+ t.Fatalf("Unexpected route got = %#v, want = %#v", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 62ab6d92f..6b8071467 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -28,7 +28,7 @@ import (
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
- fakeTransHeaderLen = 3
+ fakeTransHeaderLen int = 3
)
// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c42bb0991..3ab2b7654 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -111,6 +111,7 @@ var (
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
+ ErrMalformedHeader = &Error{msg: "header is malformed"}
)
var messageToError map[string]*Error
@@ -159,6 +160,7 @@ func StringToError(s string) *Error {
ErrBroadcastDisabled,
ErrNotPermitted,
ErrAddressFamilyNotSupported,
+ ErrMalformedHeader,
}
messageToError = make(map[string]*Error)
@@ -354,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool {
return s.Prefix() <= 30 && s.Broadcast() == address
}
-// Equal returns true if s equals o.
-//
-// Needed to use cmp.Equal on Subnet as its fields are unexported.
+// Equal returns true if this Subnet is equal to the given Subnet.
func (s Subnet) Equal(o Subnet) bool {
+ // If this changes, update Route.Equal accordingly.
return s == o
}
@@ -761,6 +762,10 @@ const (
// endpoint that all packets being written have an IP header and the
// endpoint should not attach an IP header.
IPHdrIncludedOption
+
+ // AcceptConnOption is used by GetSockOptBool to indicate if the
+ // socket is a listening socket.
+ AcceptConnOption
)
// SockOptInt represents socket options which values have the int type.
@@ -1254,6 +1259,12 @@ func (r Route) String() string {
return out.String()
}
+// Equal returns true if the given Route is equal to this Route.
+func (r Route) Equal(to Route) bool {
+ // NOTE: This relies on the fact that r.Destination == to.Destination
+ return r == to
+}
+
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -1494,6 +1505,15 @@ type IPStats struct {
// IPTablesOutputDropped is the total number of IP packets dropped in
// the Output chain.
IPTablesOutputDropped *StatCounter
+
+ // OptionTSReceived is the number of Timestamp options seen.
+ OptionTSReceived *StatCounter
+
+ // OptionRRReceived is the number of Record Route options seen.
+ OptionRRReceived *StatCounter
+
+ // OptionUnknownReceived is the number of unknown IP options seen.
+ OptionUnknownReceived *StatCounter
}
// TCPStats collects TCP-specific stats.
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index a4f141253..34aab32d0 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -16,6 +16,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/ethernet",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/pipe",
"//pkg/tcpip/network/arp",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ffd38ee1a..0dcef7b04 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -21,6 +21,7 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -178,19 +179,19 @@ func TestForwarding(t *testing.T) {
routerStack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired)
- routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+ host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr)
+ routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr)
- if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil {
+ if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err)
}
- if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil {
+ if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil {
t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err)
}
- if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index bf3a6f6ee..6ddcda70c 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/ethernet"
"gvisor.dev/gvisor/pkg/tcpip/link/pipe"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -126,12 +127,12 @@ func TestPing(t *testing.T) {
host1Stack := stack.New(stackOpts)
host2Stack := stack.New(stackOpts)
- host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired)
+ host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr)
- if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil {
+ if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil {
t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err)
}
- if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil {
+ if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil {
t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index 4f2ca7f54..f1028823b 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -80,6 +80,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: dst,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
@@ -250,6 +251,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
SrcAddr: remoteIPv4Addr,
DstAddr: dst,
})
+ ip.SetChecksum(^ip.CalculateChecksum())
e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: hdr.View().ToVectorisedView(),
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 41eb0ca44..a17234946 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
default:
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 072601d2d..31831a6d8 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrNotSupported
+ switch opt {
+ case tcpip.AcceptConnOption:
+ return false, nil
+ default:
+ return false, tcpip.ErrNotSupported
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index e37c00523..79f688129 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
- case tcpip.KeepaliveEnabledOption:
+ case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
return false, nil
case tcpip.IPHdrIncludedOption:
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
index 33bfb56cd..7d97cbdc7 100644
--- a/pkg/tcpip/transport/raw/endpoint_state.go
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) {
}
// beforeSave is invoked by stateify.
-func (ep *endpoint) beforeSave() {
+func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
// The lock will be released after saveRcvBufSizeMax(), which would have
- // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming
+ // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
// packets.
- ep.rcvMu.Lock()
+ e.rcvMu.Lock()
}
// saveRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) saveRcvBufSizeMax() int {
- max := ep.rcvBufSizeMax
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
// Make sure no new packets will be handled regardless of the lock.
- ep.rcvBufSizeMax = 0
+ e.rcvBufSizeMax = 0
// Release the lock acquired in beforeSave() so regular endpoint closing
// logic can proceed after save.
- ep.rcvMu.Unlock()
+ e.rcvMu.Unlock()
return max
}
// loadRcvBufSizeMax is invoked by stateify.
-func (ep *endpoint) loadRcvBufSizeMax(max int) {
- ep.rcvBufSizeMax = max
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
}
// afterLoad is invoked by stateify.
-func (ep *endpoint) afterLoad() {
- stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+func (e *endpoint) afterLoad() {
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
// Resume implements tcpip.ResumableEndpoint.Resume.
-func (ep *endpoint) Resume(s *stack.Stack) {
- ep.stack = s
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
// If the endpoint is connected, re-connect.
- if ep.connected {
+ if e.connected {
var err *tcpip.Error
- ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+ e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false)
if err != nil {
panic(err)
}
}
// If the endpoint is bound, re-bind.
- if ep.bound {
- if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 {
+ if e.bound {
+ if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 {
panic(tcpip.ErrBadLocalAddress)
}
}
- if ep.associated {
- if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil {
+ if e.associated {
+ if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b706438bd..6b3238d6b 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -425,20 +425,17 @@ func (e *endpoint) notifyAborted() {
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer ctx.synRcvdCount.dec()
- defer func() {
- e.mu.Lock()
- e.decSynRcvdCount()
- e.mu.Unlock()
- }()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
+ e.decSynRcvdCount()
return
}
ctx.removePendingEndpoint(n)
+ e.decSynRcvdCount()
n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
@@ -456,7 +453,9 @@ func (e *endpoint) incSynRcvdCount() bool {
}
func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
e.synRcvdCount--
+ e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 3bcd3923a..c826942e9 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
case tcpip.MulticastLoopOption:
return true, nil
+ case tcpip.AcceptConnOption:
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.EndpointState() == StateListen, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
index 7ef2df377..833a7b470 100644
--- a/pkg/tcpip/transport/tcp/sack_scoreboard.go
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
return found
}
-// Dump prints the state of the scoreboard structure.
+// String returns human-readable state of the scoreboard structure.
func (s *SACKScoreboard) String() string {
var str strings.Builder
str.WriteString("SACKScoreboard: {")
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index a7149efd0..5f05608e2 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) {
}
func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
@@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
}
func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ t.Helper()
// Send a SYN request.
irs = seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
@@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) {
// Test acceptance.
// Start listening.
- listenBacklog := 2
+ listenBacklog := 10
if err := c.EP.Listen(listenBacklog); err != nil {
t.Fatalf("Listen failed: %s", err)
}
- for i := 0; i < listenBacklog; i++ {
- executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */)
+ lastPortOffset := uint16(0)
+ for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
time.Sleep(50 * time.Millisecond)
@@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) {
// Now execute send one more SYN. The stack should not respond as the backlog
// is full at this point.
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 2,
+ SrcPort: context.TestPort + uint16(lastPortOffset),
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: seqnum.Value(789),
@@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) {
}
// Now a new handshake must succeed.
- executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */)
+ executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
newEP, _, err := c.EP.Accept(nil)
if err == tcpip.ErrWouldBlock {
@@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) {
c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+ // drain any older notifications from the notification channel before attempting
+ // 2nd connection.
+ select {
+ case <-ch:
+ default:
+ }
+
// Send a SYN request w/ sequence number higher than
// the highest sequence number sent.
iss = seqnum.Value(792)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 4d7847142..79646fefe 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -373,6 +373,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code,
const icmpv4VariableHeaderOffset = 4
copy(icmp[icmpv4VariableHeaderOffset:], p1)
copy(icmp[header.ICMPv4PayloadOffset:], p2)
+ icmp.SetChecksum(0)
+ checksum := ^header.Checksum(icmp, 0 /* initial */)
+ icmp.SetChecksum(checksum)
// Inject packet.
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index d57ed5d79..cdb5127ab 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -895,6 +895,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
return v, nil
+ case tcpip.AcceptConnOption:
+ return false, nil
+
default:
return false, tcpip.ErrUnknownProtocolOption
}
@@ -1366,6 +1369,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.rcvMu.Unlock()
}
+ e.lastErrorMu.Lock()
+ hasError := e.lastError != nil
+ e.lastErrorMu.Unlock()
+ if hasError {
+ result |= waiter.EventErr
+ }
return result
}
@@ -1465,14 +1474,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
e.mu.RLock()
- defer e.mu.RUnlock()
-
if e.state == StateConnected {
e.lastErrorMu.Lock()
- defer e.lastErrorMu.Unlock()
-
e.lastError = tcpip.ErrConnectionRefused
+ e.lastErrorMu.Unlock()
+ e.mu.RUnlock()
+
+ e.waiterQueue.Notify(waiter.EventErr)
+ return
}
+ e.mu.RUnlock()
}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b4604ba35..fb7738dda 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) {
// had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header.
wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
-
}
// In the case of large payloads the IP packet may be truncated. Update
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 67a950444..08519d986 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
//
// +stateify savable
type Queue struct {
- list waiterList `state:"zerovalue"`
+ list waiterList
mu sync.RWMutex `state:"nosave"`
}
diff --git a/runsc/BUILD b/runsc/BUILD
index 33d8554af..3b91b984a 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -13,16 +13,7 @@ go_binary(
"//visibility:public",
],
x_defs = {"main.version": "{STABLE_VERSION}"},
- deps = [
- "//pkg/log",
- "//pkg/refs",
- "//pkg/sentry/platform",
- "//runsc/cmd",
- "//runsc/config",
- "//runsc/flag",
- "//runsc/specutils",
- "@com_github_google_subcommands//:go_default_library",
- ],
+ deps = ["//runsc/cli"],
)
# The runsc-race target is a race-compatible BUILD target. This must be built
@@ -49,16 +40,7 @@ go_binary(
"//visibility:public",
],
x_defs = {"main.version": "{STABLE_VERSION}"},
- deps = [
- "//pkg/log",
- "//pkg/refs",
- "//pkg/sentry/platform",
- "//runsc/cmd",
- "//runsc/config",
- "//runsc/flag",
- "//runsc/specutils",
- "@com_github_google_subcommands//:go_default_library",
- ],
+ deps = ["//runsc/cli"],
)
sh_test(
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 2d9517f4a..b97dc3c47 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -38,6 +38,7 @@ go_library(
"//pkg/memutil",
"//pkg/rand",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sentry/arch",
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/control",
@@ -110,8 +111,8 @@ go_library(
"//runsc/config",
"//runsc/specutils",
"//runsc/specutils/seccomp",
- "@com_github_golang_protobuf//proto:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 84c67cbc2..7076ae2e2 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -19,7 +19,7 @@ import (
"os"
"syscall"
- "github.com/golang/protobuf/proto"
+ "google.golang.org/protobuf/proto"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 894651519..4e0f0d57a 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -30,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/state"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
@@ -367,12 +368,20 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
cm.l.k = k
// Set up the restore environment.
+ ctx := k.SupervisorContext()
mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints)
- renv, err := mntr.createRestoreEnvironment(cm.l.root.conf)
- if err != nil {
- return fmt.Errorf("creating RestoreEnvironment: %v", err)
+ if kernel.VFS2Enabled {
+ ctx, err = mntr.configureRestore(ctx, cm.l.root.conf)
+ if err != nil {
+ return fmt.Errorf("configuring filesystem restore: %v", err)
+ }
+ } else {
+ renv, err := mntr.createRestoreEnvironment(cm.l.root.conf)
+ if err != nil {
+ return fmt.Errorf("creating RestoreEnvironment: %v", err)
+ }
+ fs.SetRestoreEnvironment(*renv)
}
- fs.SetRestoreEnvironment(*renv)
// Prepare to load from the state file.
if eps, ok := networkStack.(*netstack.Stack); ok {
@@ -399,7 +408,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Load the state.
loadOpts := state.LoadOpts{Source: specFile}
- if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil {
+ if err := loadOpts.Load(ctx, k, networkStack, time.NewCalibratedClocks(), &vfs.CompleteRestoreOptions{}); err != nil {
return err
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index ddf288456..6b6ae98d7 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -105,33 +105,28 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name
// mandatory mounts that are required by the OCI specification.
func compileMounts(spec *specs.Spec) []specs.Mount {
// Keep track of whether proc and sys were mounted.
- var procMounted, sysMounted bool
+ var procMounted, sysMounted, devMounted, devptsMounted bool
var mounts []specs.Mount
- // Always mount /dev.
- mounts = append(mounts, specs.Mount{
- Type: devtmpfs.Name,
- Destination: "/dev",
- })
-
- mounts = append(mounts, specs.Mount{
- Type: devpts.Name,
- Destination: "/dev/pts",
- })
-
// Mount all submounts from the spec.
for _, m := range spec.Mounts {
if !specutils.IsSupportedDevMount(m) {
log.Warningf("ignoring dev mount at %q", m.Destination)
continue
}
- mounts = append(mounts, m)
switch filepath.Clean(m.Destination) {
case "/proc":
procMounted = true
case "/sys":
sysMounted = true
+ case "/dev":
+ m.Type = devtmpfs.Name
+ devMounted = true
+ case "/dev/pts":
+ m.Type = devpts.Name
+ devptsMounted = true
}
+ mounts = append(mounts, m)
}
// Mount proc and sys even if the user did not ask for it, as the spec
@@ -149,6 +144,18 @@ func compileMounts(spec *specs.Spec) []specs.Mount {
Destination: "/sys",
})
}
+ if !devMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: devtmpfs.Name,
+ Destination: "/dev",
+ })
+ }
+ if !devptsMounted {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: devpts.Name,
+ Destination: "/dev/pts",
+ })
+ }
// The mandatory mounts should be ordered right after the root, in case
// there are submounts of these mandatory mounts already in the spec.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 8ad000497..8c6ab213d 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fdimport"
@@ -476,6 +477,12 @@ func (l *Loader) Destroy() {
// save/restore.
l.k.Release()
+ // All sentry-created resources should have been released at this point;
+ // check for reference leaks.
+ if refsvfs2.LeakCheckEnabled() {
+ refsvfs2.DoLeakCheck()
+ }
+
// In the success case, stdioFDs and goferFDs will only contain
// released/closed FDs that ownership has been passed over to host FDs and
// gofer sessions. Close them here in case of failure.
@@ -737,7 +744,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
return nil, err
}
- // Add the HOME enviroment variable if it is not already set.
+ // Add the HOME environment variable if it is not already set.
var envv []string
if kernel.VFS2Enabled {
envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2,
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index e376f944b..b77b4762e 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -266,7 +266,7 @@ type CreateMountTestcase struct {
func createMountTestcases() []*CreateMountTestcase {
testCases := []*CreateMountTestcase{
- &CreateMountTestcase{
+ {
// Only proc.
name: "only proc mount",
spec: specs.Spec{
@@ -304,11 +304,10 @@ func createMountTestcases() []*CreateMountTestcase {
},
},
},
- // /some/deep/path should be mounted, along with /proc,
- // /dev, and /sys.
+ // /some/deep/path should be mounted, along with /proc, /dev, and /sys.
expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- &CreateMountTestcase{
+ {
// Mounts are nested inside each other.
name: "nested mounts",
spec: specs.Spec{
@@ -352,7 +351,7 @@ func createMountTestcases() []*CreateMountTestcase {
expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux",
"/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- &CreateMountTestcase{
+ {
name: "mount inside /dev",
spec: specs.Spec{
Root: &specs.Root{
@@ -395,35 +394,37 @@ func createMountTestcases() []*CreateMountTestcase {
},
expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"},
},
- }
-
- vfsCase := &CreateMountTestcase{
- name: "mounts inside mandatory mounts",
- spec: specs.Spec{
- Root: &specs.Root{
- Path: os.TempDir(),
- Readonly: true,
- },
- Mounts: []specs.Mount{
- {
- Destination: "/proc",
- Type: "tmpfs",
- },
- {
- Destination: "/sys/bar",
- Type: "tmpfs",
+ {
+ name: "mounts inside mandatory mounts",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
},
-
- {
- Destination: "/tmp/baz",
- Type: "tmpfs",
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/sys/bar",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/tmp/baz",
+ Type: "tmpfs",
+ },
+ {
+ Destination: "/dev/goo",
+ Type: "tmpfs",
+ },
},
},
+ expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz", "/dev/goo"},
},
- expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"},
}
- return append(testCases, vfsCase)
+ return testCases
}
// Test that MountNamespace can be created with various specs.
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 82e459f46..b157387ef 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -210,6 +210,9 @@ func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *c
ReadOnly: c.root.Readonly,
GetFilesystemOptions: vfs.GetFilesystemOptions{
Data: strings.Join(data, ","),
+ InternalData: gofer.InternalFilesystemOptions{
+ UniqueID: "/",
+ },
},
InternalMount: true,
}
@@ -264,10 +267,38 @@ func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Cre
}
cu.Add(func() { lower.DecRef(ctx) })
+ // Propagate the lower layer's root's owner, group, and mode to the upper
+ // layer's root for consistency with VFS1.
+ upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root())
+ lowerRootVD := vfs.MakeVirtualDentry(lower, lower.Root())
+ stat, err := c.k.VFS().StatAt(ctx, creds, &vfs.PathOperation{
+ Root: lowerRootVD,
+ Start: lowerRootVD,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE,
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+ err = c.k.VFS().SetStatAt(ctx, creds, &vfs.PathOperation{
+ Root: upperRootVD,
+ Start: upperRootVD,
+ }, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: (linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE) & stat.Mask,
+ UID: stat.UID,
+ GID: stat.GID,
+ Mode: stat.Mode,
+ },
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+
// Configure overlay with both layers.
overlayOpts.GetFilesystemOptions.InternalData = overlay.FilesystemOptions{
- UpperRoot: vfs.MakeVirtualDentry(upper, upper.Root()),
- LowerRoots: []vfs.VirtualDentry{vfs.MakeVirtualDentry(lower, lower.Root())},
+ UpperRoot: upperRootVD,
+ LowerRoots: []vfs.VirtualDentry{lowerRootVD},
}
return &overlayOpts, cu.Release(), nil
}
@@ -399,6 +430,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
fsName := m.Type
useOverlay := false
var data []string
+ var iopts interface{}
// Find filesystem name and FS specific data field.
switch m.Type {
@@ -423,6 +455,9 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
return "", nil, false, fmt.Errorf("9P mount requires a connection FD")
}
data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+ iopts = gofer.InternalFilesystemOptions{
+ UniqueID: m.Destination,
+ }
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
@@ -434,7 +469,8 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
opts := &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
- Data: strings.Join(data, ","),
+ Data: strings.Join(data, ","),
+ InternalData: iopts,
},
InternalMount: true,
}
@@ -639,3 +675,21 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede
}
return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds)
}
+
+// configureRestore returns an updated context.Context including filesystem
+// state used by restore defined by conf.
+func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) {
+ fdmap := make(map[string]int)
+ fdmap["/"] = c.fds.remove()
+ mounts, err := c.prepareMountsVFS2()
+ if err != nil {
+ return ctx, err
+ }
+ for i := range c.mounts {
+ submount := &mounts[i]
+ if submount.fd >= 0 {
+ fdmap[submount.Destination] = submount.fd
+ }
+ }
+ return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil
+}
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index 8fbc3887a..5bd0afc52 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -21,6 +21,7 @@ import (
"context"
"errors"
"fmt"
+ "io"
"io/ioutil"
"os"
"path/filepath"
@@ -198,16 +199,26 @@ func LoadPaths(pid string) (map[string]string, error) {
}
defer f.Close()
+ return loadPathsHelper(f)
+}
+
+func loadPathsHelper(cgroup io.Reader) (map[string]string, error) {
paths := make(map[string]string)
- scanner := bufio.NewScanner(f)
+
+ scanner := bufio.NewScanner(cgroup)
for scanner.Scan() {
- // Format: ID:controller1,controller2:path
+ // Format: ID:[name=]controller1,controller2:path
// Example: 2:cpu,cpuacct:/user.slice
tokens := strings.Split(scanner.Text(), ":")
if len(tokens) != 3 {
return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text())
}
+ if len(tokens[1]) == 0 {
+ continue
+ }
for _, ctrlr := range strings.Split(tokens[1], ",") {
+ // Remove prefix for cgroups with no controller, eg. systemd.
+ ctrlr = strings.TrimPrefix(ctrlr, "name=")
paths[ctrlr] = tokens[2]
}
}
@@ -237,7 +248,7 @@ func New(spec *specs.Spec) (*Cgroup, error) {
var err error
parents, err = LoadPaths("self")
if err != nil {
- return nil, fmt.Errorf("finding current cgroups: %v", err)
+ return nil, fmt.Errorf("finding current cgroups: %w", err)
}
}
return &Cgroup{
@@ -276,10 +287,8 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
}
return err
}
- if res != nil {
- if err := cfg.ctrlr.set(res, path); err != nil {
- return err
- }
+ if err := cfg.ctrlr.set(res, path); err != nil {
+ return err
}
}
clean.Release()
@@ -304,14 +313,15 @@ func (c *Cgroup) Uninstall() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
- if err := backoff.Retry(func() error {
+ fn := func() error {
err := syscall.Rmdir(path)
if os.IsNotExist(err) {
return nil
}
return err
- }, b); err != nil {
- return fmt.Errorf("removing cgroup path %q: %v", path, err)
+ }
+ if err := backoff.Retry(fn, b); err != nil {
+ return fmt.Errorf("removing cgroup path %q: %w", path, err)
}
}
return nil
@@ -332,7 +342,6 @@ func (c *Cgroup) Join() (func(), error) {
if _, ok := controllers[ctrlr]; ok {
fullPath := filepath.Join(cgroupRoot, ctrlr, path)
undoPaths = append(undoPaths, fullPath)
- break
}
}
@@ -422,7 +431,7 @@ func (*noop) set(*specs.LinuxResources, string) error {
type memory struct{}
func (*memory) set(spec *specs.LinuxResources, path string) error {
- if spec.Memory == nil {
+ if spec == nil || spec.Memory == nil {
return nil
}
if err := setOptionalValueInt(path, "memory.limit_in_bytes", spec.Memory.Limit); err != nil {
@@ -455,7 +464,7 @@ func (*memory) set(spec *specs.LinuxResources, path string) error {
type cpu struct{}
func (*cpu) set(spec *specs.LinuxResources, path string) error {
- if spec.CPU == nil {
+ if spec == nil || spec.CPU == nil {
return nil
}
if err := setOptionalValueUint(path, "cpu.shares", spec.CPU.Shares); err != nil {
@@ -478,7 +487,7 @@ type cpuSet struct{}
func (*cpuSet) set(spec *specs.LinuxResources, path string) error {
// cpuset.cpus and mems are required fields, but are not set on a new cgroup.
// If not set in the spec, get it from one of the ancestors cgroup.
- if spec.CPU == nil || spec.CPU.Cpus == "" {
+ if spec == nil || spec.CPU == nil || spec.CPU.Cpus == "" {
if _, err := fillFromAncestor(filepath.Join(path, "cpuset.cpus")); err != nil {
return err
}
@@ -488,18 +497,17 @@ func (*cpuSet) set(spec *specs.LinuxResources, path string) error {
}
}
- if spec.CPU == nil || spec.CPU.Mems == "" {
+ if spec == nil || spec.CPU == nil || spec.CPU.Mems == "" {
_, err := fillFromAncestor(filepath.Join(path, "cpuset.mems"))
return err
}
- mems := spec.CPU.Mems
- return setValue(path, "cpuset.mems", mems)
+ return setValue(path, "cpuset.mems", spec.CPU.Mems)
}
type blockIO struct{}
func (*blockIO) set(spec *specs.LinuxResources, path string) error {
- if spec.BlockIO == nil {
+ if spec == nil || spec.BlockIO == nil {
return nil
}
@@ -549,7 +557,7 @@ func setThrottle(path, name string, devs []specs.LinuxThrottleDevice) error {
type networkClass struct{}
func (*networkClass) set(spec *specs.LinuxResources, path string) error {
- if spec.Network == nil {
+ if spec == nil || spec.Network == nil {
return nil
}
return setOptionalValueUint32(path, "net_cls.classid", spec.Network.ClassID)
@@ -558,7 +566,7 @@ func (*networkClass) set(spec *specs.LinuxResources, path string) error {
type networkPrio struct{}
func (*networkPrio) set(spec *specs.LinuxResources, path string) error {
- if spec.Network == nil {
+ if spec == nil || spec.Network == nil {
return nil
}
for _, prio := range spec.Network.Priorities {
@@ -573,7 +581,7 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error {
type pids struct{}
func (*pids) set(spec *specs.LinuxResources, path string) error {
- if spec.Pids == nil || spec.Pids.Limit <= 0 {
+ if spec == nil || spec.Pids == nil || spec.Pids.Limit <= 0 {
return nil
}
val := strconv.FormatInt(spec.Pids.Limit, 10)
@@ -583,6 +591,9 @@ func (*pids) set(spec *specs.LinuxResources, path string) error {
type hugeTLB struct{}
func (*hugeTLB) set(spec *specs.LinuxResources, path string) error {
+ if spec == nil {
+ return nil
+ }
for _, limit := range spec.HugepageLimits {
name := fmt.Sprintf("hugetlb.%s.limit_in_bytes", limit.Pagesize)
val := strconv.FormatUint(limit.Limit, 10)
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 4db5ee5c3..9794517a7 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -647,3 +647,83 @@ func TestPids(t *testing.T) {
})
}
}
+
+func TestLoadPaths(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ cgroups string
+ want map[string]string
+ err string
+ }{
+ {
+ name: "abs-path",
+ cgroups: "0:ctr:/path",
+ want: map[string]string{"ctr": "/path"},
+ },
+ {
+ name: "rel-path",
+ cgroups: "0:ctr:rel-path",
+ want: map[string]string{"ctr": "rel-path"},
+ },
+ {
+ name: "non-controller",
+ cgroups: "0:name=systemd:/path",
+ want: map[string]string{"systemd": "/path"},
+ },
+ {
+ name: "empty",
+ },
+ {
+ name: "multiple",
+ cgroups: "0:ctr0:/path0\n" +
+ "1:ctr1:/path1\n" +
+ "2::/empty\n",
+ want: map[string]string{
+ "ctr0": "/path0",
+ "ctr1": "/path1",
+ },
+ },
+ {
+ name: "missing-field",
+ cgroups: "0:nopath\n",
+ err: "invalid cgroups file",
+ },
+ {
+ name: "too-many-fields",
+ cgroups: "0:ctr:/path:extra\n",
+ err: "invalid cgroups file",
+ },
+ {
+ name: "multiple-malformed",
+ cgroups: "0:ctr0:/path0\n" +
+ "1:ctr1:/path1\n" +
+ "2:\n",
+ err: "invalid cgroups file",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ r := strings.NewReader(tc.cgroups)
+ got, err := loadPathsHelper(r)
+ if len(tc.err) == 0 {
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ } else if !strings.Contains(err.Error(), tc.err) {
+ t.Fatalf("Wrong error message, want: *%s*, got: %v", tc.err, err)
+ }
+ for key, vWant := range tc.want {
+ vGot, ok := got[key]
+ if !ok {
+ t.Errorf("Missing controller %q", key)
+ }
+ if vWant != vGot {
+ t.Errorf("Wrong controller %q value, want: %q, got: %q", key, vWant, vGot)
+ }
+ delete(got, key)
+ }
+ for k, v := range got {
+ t.Errorf("Unexpected controller %q: %q", k, v)
+ }
+ })
+ }
+}
diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD
new file mode 100644
index 000000000..32cce2a18
--- /dev/null
+++ b/runsc/cli/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cli",
+ srcs = ["main.go"],
+ visibility = [
+ "//:__pkg__",
+ "//runsc:__pkg__",
+ ],
+ deps = [
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/platform",
+ "//runsc/cmd",
+ "//runsc/config",
+ "//runsc/flag",
+ "//runsc/specutils",
+ "@com_github_google_subcommands//:go_default_library",
+ ],
+)
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
new file mode 100644
index 000000000..bca015db5
--- /dev/null
+++ b/runsc/cli/main.go
@@ -0,0 +1,256 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cli is the main entrypoint for runsc.
+package cli
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/runsc/cmd"
+ "gvisor.dev/gvisor/runsc/config"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ // Although these flags are not part of the OCI spec, they are used by
+ // Docker, and thus should not be changed.
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.")
+ showVersion = flag.Bool("version", false, "show version and exit.")
+
+ // These flags are unique to runsc, and are used to configure parts of the
+ // system that are not covered by the runtime spec.
+
+ // Debugging flags.
+ logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
+ debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
+ panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
+)
+
+// Main is the main entrypoint.
+func Main(version string) {
+ // Help and flags commands are generated automatically.
+ help := cmd.NewHelp(subcommands.DefaultCommander)
+ help.Register(new(cmd.Syscalls))
+ subcommands.Register(help, "")
+ subcommands.Register(subcommands.FlagsCommand(), "")
+
+ // Installation helpers.
+ const helperGroup = "helpers"
+ subcommands.Register(new(cmd.Install), helperGroup)
+ subcommands.Register(new(cmd.Uninstall), helperGroup)
+
+ // Register user-facing runsc commands.
+ subcommands.Register(new(cmd.Checkpoint), "")
+ subcommands.Register(new(cmd.Create), "")
+ subcommands.Register(new(cmd.Delete), "")
+ subcommands.Register(new(cmd.Do), "")
+ subcommands.Register(new(cmd.Events), "")
+ subcommands.Register(new(cmd.Exec), "")
+ subcommands.Register(new(cmd.Gofer), "")
+ subcommands.Register(new(cmd.Kill), "")
+ subcommands.Register(new(cmd.List), "")
+ subcommands.Register(new(cmd.Pause), "")
+ subcommands.Register(new(cmd.PS), "")
+ subcommands.Register(new(cmd.Restore), "")
+ subcommands.Register(new(cmd.Resume), "")
+ subcommands.Register(new(cmd.Run), "")
+ subcommands.Register(new(cmd.Spec), "")
+ subcommands.Register(new(cmd.State), "")
+ subcommands.Register(new(cmd.Start), "")
+ subcommands.Register(new(cmd.Wait), "")
+
+ // Register internal commands with the internal group name. This causes
+ // them to be sorted below the user-facing commands with empty group.
+ // The string below will be printed above the commands.
+ const internalGroup = "internal use only"
+ subcommands.Register(new(cmd.Boot), internalGroup)
+ subcommands.Register(new(cmd.Debug), internalGroup)
+ subcommands.Register(new(cmd.Gofer), internalGroup)
+ subcommands.Register(new(cmd.Statefile), internalGroup)
+
+ config.RegisterFlags()
+
+ // All subcommands must be registered before flag parsing.
+ flag.Parse()
+
+ // Are we showing the version?
+ if *showVersion {
+ // The format here is the same as runc.
+ fmt.Fprintf(os.Stdout, "runsc version %s\n", version)
+ fmt.Fprintf(os.Stdout, "spec: %s\n", specutils.Version)
+ os.Exit(0)
+ }
+
+ // Create a new Config from the flags.
+ conf, err := config.NewFromFlags()
+ if err != nil {
+ cmd.Fatalf(err.Error())
+ }
+
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ if *systemdCgroup {
+ fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193")
+ os.Exit(1)
+ }
+
+ var errorLogger io.Writer
+ if *logFD > -1 {
+ errorLogger = os.NewFile(uintptr(*logFD), "error log file")
+
+ } else if conf.LogFilename != "" {
+ // We must set O_APPEND and not O_TRUNC because Docker passes
+ // the same log file for all commands (and also parses these
+ // log files), so we can't destroy them on each command.
+ var err error
+ errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
+ if err != nil {
+ cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err)
+ }
+ }
+ cmd.ErrorLogger = errorLogger
+
+ if _, err := platform.Lookup(conf.Platform); err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ // Sets the reference leak check mode. Also set it in config below to
+ // propagate it to child processes.
+ refs.SetLeakMode(conf.ReferenceLeak)
+
+ // Set up logging.
+ if conf.Debug {
+ log.SetLevel(log.Debug)
+ }
+
+ // Logging will include the local date and time via the time package.
+ //
+ // On first use, time.Local initializes the local time zone, which
+ // involves opening tzdata files on the host. Since this requires
+ // opening host files, it must be done before syscall filter
+ // installation.
+ //
+ // Generally there will be a log message before filter installation
+ // that will force initialization, but force initialization here in
+ // case that does not occur.
+ _ = time.Local.String()
+
+ subcommand := flag.CommandLine.Arg(0)
+
+ var e log.Emitter
+ if *debugLogFD > -1 {
+ f := os.NewFile(uintptr(*debugLogFD), "debug log file")
+
+ e = newEmitter(conf.DebugLogFormat, f)
+
+ } else if conf.DebugLog != "" {
+ f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */)
+ if err != nil {
+ cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err)
+ }
+ e = newEmitter(conf.DebugLogFormat, f)
+
+ } else {
+ // Stderr is reserved for the application, just discard the logs if no debug
+ // log is specified.
+ e = newEmitter("text", ioutil.Discard)
+ }
+
+ if *panicLogFD > -1 || *debugLogFD > -1 {
+ fd := *panicLogFD
+ if fd < 0 {
+ fd = *debugLogFD
+ }
+ // Quick sanity check to make sure no other commands get passed
+ // a log fd (they should use log dir instead).
+ if subcommand != "boot" && subcommand != "gofer" {
+ cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
+ }
+
+ // If we are the boot process, then we own our stdio FDs and can do what we
+ // want with them. Since Docker and Containerd both eat boot's stderr, we
+ // dup our stderr to the provided log FD so that panics will appear in the
+ // logs, rather than just disappear.
+ if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
+ }
+ } else if conf.AlsoLogToStderr {
+ e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)}
+ }
+
+ log.SetTarget(e)
+
+ log.Infof("***************************")
+ log.Infof("Args: %s", os.Args)
+ log.Infof("Version %s", version)
+ log.Infof("PID: %d", os.Getpid())
+ log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid())
+ log.Infof("Configuration:")
+ log.Infof("\t\tRootDir: %s", conf.RootDir)
+ log.Infof("\t\tPlatform: %v", conf.Platform)
+ log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
+ log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
+ log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
+ log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
+ log.Infof("***************************")
+
+ if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ // SIGTERM is sent to all processes if a test exceeds its
+ // timeout and this case is handled by syscall_test_runner.
+ log.Warningf("Block the TERM signal. This is only safe in tests!")
+ signal.Ignore(syscall.SIGTERM)
+ }
+
+ // Call the subcommand and pass in the configuration.
+ var ws syscall.WaitStatus
+ subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
+ if subcmdCode == subcommands.ExitSuccess {
+ log.Infof("Exiting with status: %v", ws)
+ if ws.Signaled() {
+ // No good way to return it, emulate what the shell does. Maybe raise
+ // signal to self?
+ os.Exit(128 + int(ws.Signal()))
+ }
+ os.Exit(ws.ExitStatus())
+ }
+ // Return an error that is unlikely to be used by the application.
+ log.Warningf("Failure to execute command, err: %v", subcmdCode)
+ os.Exit(128)
+}
+
+func newEmitter(format string, logFile io.Writer) log.Emitter {
+ switch format {
+ case "text":
+ return log.GoogleEmitter{&log.Writer{Next: logFile}}
+ case "json":
+ return log.JSONEmitter{&log.Writer{Next: logFile}}
+ case "json-k8s":
+ return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
+ }
+ cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
+ panic("unreachable")
+}
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index cd419e1aa..2c92e3067 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -131,11 +131,11 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitUsageError
}
- // Ensure that if there is a panic, all goroutine stacks are printed.
- debug.SetTraceback("system")
-
conf := args[0].(*config.Config)
+ // Set traceback level
+ debug.SetTraceback(conf.Traceback)
+
if b.attached {
// Ensure this process is killed after parent process terminates when
// attached mode is enabled. In the unfortunate event that the parent
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index d1f2e9e6d..640de4c47 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -17,6 +17,7 @@ package cmd
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"io/ioutil"
"math/rand"
@@ -36,6 +37,8 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
+var errNoDefaultInterface = errors.New("no default interface found")
+
// Do implements subcommands.Command for the "do" command. It sets up a simple
// sandbox and executes the command inside it. See Usage() for more details.
type Do struct {
@@ -126,26 +129,28 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
if conf.Network == config.NetworkNone {
- netns := specs.LinuxNamespace{
- Type: specs.NetworkNamespace,
- }
- if spec.Linux != nil {
- panic("spec.Linux is not nil")
- }
- spec.Linux = &specs.Linux{Namespaces: []specs.LinuxNamespace{netns}}
+ addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace})
} else if conf.Rootless {
if conf.Network == config.NetworkSandbox {
- c.notifyUser("*** Warning: using host network due to --rootless ***")
+ c.notifyUser("*** Warning: sandbox network isn't supported with --rootless, switching to host ***")
conf.Network = config.NetworkHost
}
} else {
- clean, err := c.setupNet(cid, spec)
- if err != nil {
+ switch clean, err := c.setupNet(cid, spec); err {
+ case errNoDefaultInterface:
+ log.Warningf("Network interface not found, using internal network")
+ addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace})
+ conf.Network = config.NetworkHost
+
+ case nil:
+ // Setup successfull.
+ defer clean()
+
+ default:
return Errorf("Error setting up network: %v", err)
}
- defer clean()
}
out, err := json.Marshal(spec)
@@ -199,6 +204,13 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
return subcommands.ExitSuccess
}
+func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) {
+ if spec.Linux == nil {
+ spec.Linux = &specs.Linux{}
+ }
+ spec.Linux.Namespaces = append(spec.Linux.Namespaces, ns)
+}
+
func (c *Do) notifyUser(format string, v ...interface{}) {
if !c.quiet {
fmt.Printf(format+"\n", v...)
@@ -219,10 +231,14 @@ func resolvePath(path string) (string, error) {
return path, nil
}
+// setupNet setups up the sandbox network, including the creation of a network
+// namespace, and iptable rules to redirect the traffic. Returns a cleanup
+// function to tear down the network. Returns errNoDefaultInterface when there
+// is no network interface available to setup the network.
func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) {
dev, err := defaultDevice()
if err != nil {
- return nil, err
+ return nil, errNoDefaultInterface
}
peerIP, err := calculatePeerIP(c.ip)
if err != nil {
@@ -279,14 +295,11 @@ func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) {
return nil, err
}
- if spec.Linux == nil {
- spec.Linux = &specs.Linux{}
- }
netns := specs.LinuxNamespace{
Type: specs.NetworkNamespace,
Path: filepath.Join("/var/run/netns", cid),
}
- spec.Linux.Namespaces = append(spec.Linux.Namespaces, netns)
+ addNamespace(spec, netns)
return func() { c.cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath) }, nil
}
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index 88991b521..139edbd49 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
)
// Start implements subcommands.Command for the "start" command.
@@ -58,6 +59,12 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
if err != nil {
Fatalf("loading container: %v", err)
}
+ // Read the spec again here to ensure flag annotations from the spec are
+ // applied to "conf".
+ if _, err := specutils.ReadSpec(c.BundleDir, conf); err != nil {
+ Fatalf("reading spec: %v", err)
+ }
+
if err := c.Start(conf); err != nil {
Fatalf("starting container: %v", err)
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index f30f79f68..b02d8e2e1 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -37,6 +37,9 @@ type Config struct {
// RootDir is the runtime root directory.
RootDir string `flag:"root"`
+ // Traceback changes the Go runtime's traceback level.
+ Traceback string `flag:"traceback"`
+
// Debug indicates that debug logging should be enabled.
Debug bool `flag:"debug"`
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index a5f25cfa2..d3203b565 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -49,6 +49,7 @@ func RegisterFlags() {
flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
flag.Bool("alsologtostderr", false, "send log messages to stderr.")
flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.")
+ flag.String("traceback", "system", "golang runtime's traceback level")
// Debugging flags: strace related
flag.Bool("strace", false, "enable strace.")
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 63478ba8c..52e1755ce 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -312,6 +312,14 @@ func New(conf *config.Config, args Args) (*Container, error) {
if isRoot(args.Spec) {
log.Debugf("Creating new sandbox for container %q", args.ID)
+ if args.Spec.Linux == nil {
+ args.Spec.Linux = &specs.Linux{}
+ }
+ // Don't force the use of cgroups in tests because they lack permission to do so.
+ if args.Spec.Linux.CgroupsPath == "" && !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
+ args.Spec.Linux.CgroupsPath = "/" + args.ID
+ }
+
// Create and join cgroup before processes are created to ensure they are
// part of the cgroup from the start (and all their children processes).
cg, err := cgroup.New(args.Spec)
@@ -321,7 +329,13 @@ func New(conf *config.Config, args Args) (*Container, error) {
if cg != nil {
// If there is cgroup config, install it before creating sandbox process.
if err := cg.Install(args.Spec.Linux.Resources); err != nil {
- return nil, fmt.Errorf("configuring cgroup: %v", err)
+ switch {
+ case errors.Is(err, syscall.EACCES) && conf.Rootless:
+ log.Warningf("Skipping cgroup configuration in rootless mode: %v", err)
+ cg = nil
+ default:
+ return nil, fmt.Errorf("configuring cgroup: %v", err)
+ }
}
}
if err := runInCgroup(cg, func() error {
@@ -985,7 +999,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
// Start the gofer in the given namespace.
log.Debugf("Starting gofer: %s %v", binPath, args)
if err := specutils.StartInNS(cmd, nss); err != nil {
- return nil, nil, fmt.Errorf("Gofer: %v", err)
+ return nil, nil, fmt.Errorf("gofer: %v", err)
}
log.Infof("Gofer started, PID: %d", cmd.Process.Pid)
c.GoferPid = cmd.Process.Pid
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 1f8e277cc..cc188f45b 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -2362,12 +2362,12 @@ func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte,
}
// executeSync synchronously executes a new process.
-func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
- pid, err := cont.Execute(args)
+func (c *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
+ pid, err := c.Execute(args)
if err != nil {
return 0, fmt.Errorf("error executing: %v", err)
}
- ws, err := cont.WaitPID(pid)
+ ws, err := c.WaitPID(pid)
if err != nil {
return 0, fmt.Errorf("error waiting: %v", err)
}
diff --git a/runsc/main.go b/runsc/main.go
index ed244c4ba..4ce5ebee9 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,245 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary runsc is an implementation of the Open Container Initiative Runtime
-// that runs applications inside a sandbox.
+// Binary runsc implements the OCI runtime interface.
package main
import (
- "context"
- "fmt"
- "io"
- "io/ioutil"
- "os"
- "os/signal"
- "syscall"
- "time"
-
- "github.com/google/subcommands"
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/runsc/cmd"
- "gvisor.dev/gvisor/runsc/config"
- "gvisor.dev/gvisor/runsc/flag"
- "gvisor.dev/gvisor/runsc/specutils"
-)
-
-var (
- // Although these flags are not part of the OCI spec, they are used by
- // Docker, and thus should not be changed.
- // TODO(gvisor.dev/issue/193): support systemd cgroups
- systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.")
- showVersion = flag.Bool("version", false, "show version and exit.")
-
- // These flags are unique to runsc, and are used to configure parts of the
- // system that are not covered by the runtime spec.
-
- // Debugging flags.
- logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
- debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
- panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
+ "gvisor.dev/gvisor/runsc/cli"
)
func main() {
- // Help and flags commands are generated automatically.
- help := cmd.NewHelp(subcommands.DefaultCommander)
- help.Register(new(cmd.Syscalls))
- subcommands.Register(help, "")
- subcommands.Register(subcommands.FlagsCommand(), "")
-
- // Installation helpers.
- const helperGroup = "helpers"
- subcommands.Register(new(cmd.Install), helperGroup)
- subcommands.Register(new(cmd.Uninstall), helperGroup)
-
- // Register user-facing runsc commands.
- subcommands.Register(new(cmd.Checkpoint), "")
- subcommands.Register(new(cmd.Create), "")
- subcommands.Register(new(cmd.Delete), "")
- subcommands.Register(new(cmd.Do), "")
- subcommands.Register(new(cmd.Events), "")
- subcommands.Register(new(cmd.Exec), "")
- subcommands.Register(new(cmd.Gofer), "")
- subcommands.Register(new(cmd.Kill), "")
- subcommands.Register(new(cmd.List), "")
- subcommands.Register(new(cmd.Pause), "")
- subcommands.Register(new(cmd.PS), "")
- subcommands.Register(new(cmd.Restore), "")
- subcommands.Register(new(cmd.Resume), "")
- subcommands.Register(new(cmd.Run), "")
- subcommands.Register(new(cmd.Spec), "")
- subcommands.Register(new(cmd.State), "")
- subcommands.Register(new(cmd.Start), "")
- subcommands.Register(new(cmd.Wait), "")
-
- // Register internal commands with the internal group name. This causes
- // them to be sorted below the user-facing commands with empty group.
- // The string below will be printed above the commands.
- const internalGroup = "internal use only"
- subcommands.Register(new(cmd.Boot), internalGroup)
- subcommands.Register(new(cmd.Debug), internalGroup)
- subcommands.Register(new(cmd.Gofer), internalGroup)
- subcommands.Register(new(cmd.Statefile), internalGroup)
-
- config.RegisterFlags()
-
- // All subcommands must be registered before flag parsing.
- flag.Parse()
-
- // Are we showing the version?
- if *showVersion {
- // The format here is the same as runc.
- fmt.Fprintf(os.Stdout, "runsc version %s\n", version)
- fmt.Fprintf(os.Stdout, "spec: %s\n", specutils.Version)
- os.Exit(0)
- }
-
- // Create a new Config from the flags.
- conf, err := config.NewFromFlags()
- if err != nil {
- cmd.Fatalf(err.Error())
- }
-
- // TODO(gvisor.dev/issue/193): support systemd cgroups
- if *systemdCgroup {
- fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193")
- os.Exit(1)
- }
-
- var errorLogger io.Writer
- if *logFD > -1 {
- errorLogger = os.NewFile(uintptr(*logFD), "error log file")
-
- } else if conf.LogFilename != "" {
- // We must set O_APPEND and not O_TRUNC because Docker passes
- // the same log file for all commands (and also parses these
- // log files), so we can't destroy them on each command.
- var err error
- errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
- if err != nil {
- cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err)
- }
- }
- cmd.ErrorLogger = errorLogger
-
- if _, err := platform.Lookup(conf.Platform); err != nil {
- cmd.Fatalf("%v", err)
- }
-
- // Sets the reference leak check mode. Also set it in config below to
- // propagate it to child processes.
- refs.SetLeakMode(conf.ReferenceLeak)
-
- // Set up logging.
- if conf.Debug {
- log.SetLevel(log.Debug)
- }
-
- // Logging will include the local date and time via the time package.
- //
- // On first use, time.Local initializes the local time zone, which
- // involves opening tzdata files on the host. Since this requires
- // opening host files, it must be done before syscall filter
- // installation.
- //
- // Generally there will be a log message before filter installation
- // that will force initialization, but force initialization here in
- // case that does not occur.
- _ = time.Local.String()
-
- subcommand := flag.CommandLine.Arg(0)
-
- var e log.Emitter
- if *debugLogFD > -1 {
- f := os.NewFile(uintptr(*debugLogFD), "debug log file")
-
- e = newEmitter(conf.DebugLogFormat, f)
-
- } else if conf.DebugLog != "" {
- f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */)
- if err != nil {
- cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err)
- }
- e = newEmitter(conf.DebugLogFormat, f)
-
- } else {
- // Stderr is reserved for the application, just discard the logs if no debug
- // log is specified.
- e = newEmitter("text", ioutil.Discard)
- }
-
- if *panicLogFD > -1 || *debugLogFD > -1 {
- fd := *panicLogFD
- if fd < 0 {
- fd = *debugLogFD
- }
- // Quick sanity check to make sure no other commands get passed
- // a log fd (they should use log dir instead).
- if subcommand != "boot" && subcommand != "gofer" {
- cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
- }
-
- // If we are the boot process, then we own our stdio FDs and can do what we
- // want with them. Since Docker and Containerd both eat boot's stderr, we
- // dup our stderr to the provided log FD so that panics will appear in the
- // logs, rather than just disappear.
- if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
- cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
- }
- } else if conf.AlsoLogToStderr {
- e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)}
- }
-
- log.SetTarget(e)
-
- log.Infof("***************************")
- log.Infof("Args: %s", os.Args)
- log.Infof("Version %s", version)
- log.Infof("PID: %d", os.Getpid())
- log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid())
- log.Infof("Configuration:")
- log.Infof("\t\tRootDir: %s", conf.RootDir)
- log.Infof("\t\tPlatform: %v", conf.Platform)
- log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
- log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
- log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
- log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
- log.Infof("***************************")
-
- if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // SIGTERM is sent to all processes if a test exceeds its
- // timeout and this case is handled by syscall_test_runner.
- log.Warningf("Block the TERM signal. This is only safe in tests!")
- signal.Ignore(syscall.SIGTERM)
- }
-
- // Call the subcommand and pass in the configuration.
- var ws syscall.WaitStatus
- subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
- if subcmdCode == subcommands.ExitSuccess {
- log.Infof("Exiting with status: %v", ws)
- if ws.Signaled() {
- // No good way to return it, emulate what the shell does. Maybe raise
- // signal to self?
- os.Exit(128 + int(ws.Signal()))
- }
- os.Exit(ws.ExitStatus())
- }
- // Return an error that is unlikely to be used by the application.
- log.Warningf("Failure to execute command, err: %v", subcmdCode)
- os.Exit(128)
-}
-
-func newEmitter(format string, logFile io.Writer) log.Emitter {
- switch format {
- case "text":
- return log.GoogleEmitter{&log.Writer{Next: logFile}}
- case "json":
- return log.JSONEmitter{&log.Writer{Next: logFile}}
- case "json-k8s":
- return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
- }
- cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
- panic("unreachable")
+ cli.Main(version)
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 0392e3e83..45abc1425 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -344,15 +344,9 @@ func IsSupportedDevMount(m specs.Mount) bool {
var existingDevices = []string{
"/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr",
"/dev/null", "/dev/zero", "/dev/full", "/dev/random",
- "/dev/urandom", "/dev/shm", "/dev/pts", "/dev/ptmx",
+ "/dev/urandom", "/dev/shm", "/dev/ptmx",
}
dst := filepath.Clean(m.Destination)
- if dst == "/dev" {
- // OCI spec uses many different mounts for the things inside of '/dev'. We
- // have a single mount at '/dev' that is always mounted, regardless of
- // whether it was asked for, as the spec says we SHOULD.
- return false
- }
for _, dev := range existingDevices {
if dst == dev || strings.HasPrefix(dst, dev+"/") {
return false
@@ -425,7 +419,7 @@ func Mount(src, dst, typ string, flags uint32) error {
// Special case, as there is no source directory for proc mounts.
isDir = true
} else if fi, err := os.Stat(src); err != nil {
- return fmt.Errorf("Stat(%q) failed: %v", src, err)
+ return fmt.Errorf("stat(%q) failed: %v", src, err)
} else {
isDir = fi.IsDir()
}
@@ -433,25 +427,25 @@ func Mount(src, dst, typ string, flags uint32) error {
if isDir {
// Create the destination directory.
if err := os.MkdirAll(dst, 0777); err != nil {
- return fmt.Errorf("Mkdir(%q) failed: %v", dst, err)
+ return fmt.Errorf("mkdir(%q) failed: %v", dst, err)
}
} else {
// Create the parent destination directory.
parent := path.Dir(dst)
if err := os.MkdirAll(parent, 0777); err != nil {
- return fmt.Errorf("Mkdir(%q) failed: %v", parent, err)
+ return fmt.Errorf("mkdir(%q) failed: %v", parent, err)
}
// Create the destination file if it does not exist.
f, err := os.OpenFile(dst, syscall.O_CREAT, 0777)
if err != nil {
- return fmt.Errorf("Open(%q) failed: %v", dst, err)
+ return fmt.Errorf("open(%q) failed: %v", dst, err)
}
f.Close()
}
// Do the mount.
if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil {
- return fmt.Errorf("Mount(%q, %q, %d) failed: %v", src, dst, flags, err)
+ return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err)
}
return nil
}
diff --git a/shim/v1/BUILD b/shim/v1/BUILD
index 4c9e2c2c6..3614a67d1 100644
--- a/shim/v1/BUILD
+++ b/shim/v1/BUILD
@@ -4,27 +4,10 @@ package(licenses = ["notice"])
go_binary(
name = "gvisor-containerd-shim",
- srcs = [
- "api.go",
- "config.go",
- "main.go",
- ],
+ srcs = ["main.go"],
static = True,
visibility = [
"//visibility:public",
],
- deps = [
- "//pkg/shim/runsc",
- "//pkg/shim/v1/shim",
- "@com_github_burntsushi_toml//:go_default_library",
- "@com_github_containerd_containerd//events:go_default_library",
- "@com_github_containerd_containerd//namespaces:go_default_library",
- "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
- "@com_github_containerd_containerd//sys:go_default_library",
- "@com_github_containerd_containerd//sys/reaper:go_default_library",
- "@com_github_containerd_ttrpc//:go_default_library",
- "@com_github_containerd_typeurl//:go_default_library",
- "@com_github_gogo_protobuf//types:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
+ deps = ["//shim/v1/cli"],
)
diff --git a/shim/v1/cli/BUILD b/shim/v1/cli/BUILD
new file mode 100644
index 000000000..0bbdc4add
--- /dev/null
+++ b/shim/v1/cli/BUILD
@@ -0,0 +1,30 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cli",
+ srcs = [
+ "api.go",
+ "cli.go",
+ "config.go",
+ ],
+ visibility = [
+ "//:__pkg__",
+ "//shim/v1:__pkg__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/shim",
+ "@com_github_burntsushi_toml//:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+ "@com_github_containerd_containerd//sys:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_ttrpc//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/shim/v1/api.go b/shim/v1/cli/api.go
index 2444d23f1..050793094 100644
--- a/shim/v1/api.go
+++ b/shim/v1/cli/api.go
@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package cli
import (
shim "github.com/containerd/containerd/runtime/v1/shim/v1"
diff --git a/shim/v1/cli/cli.go b/shim/v1/cli/cli.go
new file mode 100644
index 000000000..1a502eabd
--- /dev/null
+++ b/shim/v1/cli/cli.go
@@ -0,0 +1,267 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cli defines the command line interface for the V1 shim.
+package cli
+
+import (
+ "bytes"
+ "context"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/sys"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/ttrpc"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/shim"
+)
+
+var (
+ debugFlag bool
+ namespaceFlag string
+ socketFlag string
+ addressFlag string
+ workdirFlag string
+ runtimeRootFlag string
+ containerdBinaryFlag string
+ shimConfigFlag string
+)
+
+// Containerd defaults to runc, unless another runtime is explicitly specified.
+// We keep the same default to make the default behavior consistent.
+const defaultRoot = "/run/containerd/runc"
+
+func init() {
+ flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
+ flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
+ flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+ flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
+ flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data")
+ flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime")
+
+ // Currently, the `containerd publish` utility is embedded in the
+ // daemon binary. The daemon invokes `containerd-shim
+ // -containerd-binary ...` with its own os.Executable() path.
+ flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)")
+ flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file")
+}
+
+// Main is the main entrypoint.
+func Main() {
+ flag.Parse()
+
+ // This is a hack. Exec current process to run standard containerd-shim
+ // if runtime root is not `runsc`. We don't need this for shim v2 api.
+ if filepath.Base(runtimeRootFlag) != "runsc" {
+ if err := executeRuncShim(); err != nil {
+ fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
+ os.Exit(1)
+ }
+ }
+
+ // Run regular shim if needed.
+ if err := executeShim(); err != nil {
+ fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
+ os.Exit(1)
+ }
+}
+
+// executeRuncShim execs current process to a containerd-shim process and
+// retains all flags and envs.
+func executeRuncShim() error {
+ c, err := loadConfig(shimConfigFlag)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to load shim config: %w", err)
+ }
+ shimPath := c.RuncShim
+ if shimPath == "" {
+ shimPath, err = exec.LookPath("containerd-shim")
+ if err != nil {
+ return fmt.Errorf("lookup containerd-shim failed: %w", err)
+ }
+ }
+
+ args := append([]string{shimPath}, os.Args[1:]...)
+ if err := syscall.Exec(shimPath, args, os.Environ()); err != nil {
+ return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err)
+ }
+ return nil
+}
+
+func executeShim() error {
+ // start handling signals as soon as possible so that things are
+ // properly reaped or if runtime exits before we hit the handler.
+ signals, err := setupSignals()
+ if err != nil {
+ return err
+ }
+ path, err := os.Getwd()
+ if err != nil {
+ return err
+ }
+ server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser()))
+ if err != nil {
+ return fmt.Errorf("failed creating server: %w", err)
+ }
+ c, err := loadConfig(shimConfigFlag)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to load shim config: %w", err)
+ }
+ sv, err := shim.NewService(
+ shim.Config{
+ Path: path,
+ Namespace: namespaceFlag,
+ WorkDir: workdirFlag,
+ RuntimeRoot: runtimeRootFlag,
+ RunscConfig: c.RunscConfig,
+ },
+ &remoteEventsPublisher{address: addressFlag},
+ )
+ if err != nil {
+ return err
+ }
+ registerShimService(server, sv)
+ if err := serve(server, socketFlag); err != nil {
+ return err
+ }
+ return handleSignals(signals, server, sv)
+}
+
+// serve serves the ttrpc API over a unix socket at the provided path this
+// function does not block.
+func serve(server *ttrpc.Server, path string) error {
+ var (
+ l net.Listener
+ err error
+ )
+ if path == "" {
+ l, err = net.FileListener(os.NewFile(3, "socket"))
+ path = "[inherited from parent]"
+ } else {
+ if len(path) > 106 {
+ return fmt.Errorf("%q: unix socket path too long (> 106)", path)
+ }
+ l, err = net.Listen("unix", "\x00"+path)
+ }
+ if err != nil {
+ return err
+ }
+ go func() {
+ defer l.Close()
+ err := server.Serve(context.Background(), l)
+ if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
+ log.Fatalf("ttrpc server failure: %v", err)
+ }
+ }()
+ return nil
+}
+
+// setupSignals creates a new signal handler for all signals and sets the shim
+// as a sub-reaper so that the container processes are reparented.
+func setupSignals() (chan os.Signal, error) {
+ signals := make(chan os.Signal, 32)
+ signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE)
+ // make sure runc is setup to use the monitor for waiting on processes.
+ // TODO(random-liu): Move shim/reaper.go to a separate package.
+ runsc.Monitor = reaper.Default
+ // Set the shim as the subreaper for all orphaned processes created by
+ // the container.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ return nil, err
+ }
+ return signals, nil
+}
+
+func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error {
+ var (
+ termOnce sync.Once
+ done = make(chan struct{})
+ )
+
+ for {
+ select {
+ case <-done:
+ return nil
+ case s := <-signals:
+ switch s {
+ case unix.SIGCHLD:
+ if _, err := sys.Reap(false); err != nil {
+ log.Printf("reap error: %v", err)
+ }
+ case unix.SIGTERM, unix.SIGINT:
+ go termOnce.Do(func() {
+ ctx := context.TODO()
+ if err := server.Shutdown(ctx); err != nil {
+ log.Printf("failed to shutdown server: %v", err)
+ }
+ // Ensure our child is dead if any.
+ sv.Kill(ctx, &KillRequest{
+ Signal: uint32(syscall.SIGKILL),
+ All: true,
+ })
+ sv.Delete(context.Background(), &types.Empty{})
+ close(done)
+ })
+ case unix.SIGPIPE:
+ }
+ }
+ }
+}
+
+type remoteEventsPublisher struct {
+ address string
+}
+
+func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error {
+ ns, _ := namespaces.Namespace(ctx)
+ encoded, err := typeurl.MarshalAny(event)
+ if err != nil {
+ return err
+ }
+ data, err := encoded.Marshal()
+ if err != nil {
+ return err
+ }
+ cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns)
+ cmd.Stdin = bytes.NewReader(data)
+ c, err := reaper.Default.Start(cmd)
+ if err != nil {
+ return err
+ }
+ status, err := reaper.Default.Wait(cmd, c)
+ if err != nil {
+ return fmt.Errorf("failed to publish event: %w", err)
+ }
+ if status != 0 {
+ return fmt.Errorf("failed to publish event: status %d", status)
+ }
+ return nil
+}
diff --git a/shim/v1/config.go b/shim/v1/cli/config.go
index a72cc7754..1be9597ed 100644
--- a/shim/v1/config.go
+++ b/shim/v1/cli/config.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package main
+package cli
import "github.com/BurntSushi/toml"
diff --git a/shim/v1/main.go b/shim/v1/main.go
index 3159923af..11ff4add1 100644
--- a/shim/v1/main.go
+++ b/shim/v1/main.go
@@ -1,5 +1,4 @@
-// Copyright 2018 The containerd Authors.
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,253 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Binary gvisor-containerd-shim is the v1 containerd shim.
package main
import (
- "bytes"
- "context"
- "flag"
- "fmt"
- "log"
- "net"
- "os"
- "os/exec"
- "os/signal"
- "path/filepath"
- "strings"
- "sync"
- "syscall"
-
- "github.com/containerd/containerd/events"
- "github.com/containerd/containerd/namespaces"
- "github.com/containerd/containerd/sys"
- "github.com/containerd/containerd/sys/reaper"
- "github.com/containerd/ttrpc"
- "github.com/containerd/typeurl"
- "github.com/gogo/protobuf/types"
- "golang.org/x/sys/unix"
-
- "gvisor.dev/gvisor/pkg/shim/runsc"
- "gvisor.dev/gvisor/pkg/shim/v1/shim"
-)
-
-var (
- debugFlag bool
- namespaceFlag string
- socketFlag string
- addressFlag string
- workdirFlag string
- runtimeRootFlag string
- containerdBinaryFlag string
- shimConfigFlag string
+ "gvisor.dev/gvisor/shim/v1/cli"
)
-// Containerd defaults to runc, unless another runtime is explicitly specified.
-// We keep the same default to make the default behavior consistent.
-const defaultRoot = "/run/containerd/runc"
-
-func init() {
- flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
- flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
- flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
- flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
- flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data")
- flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime")
-
- // Currently, the `containerd publish` utility is embedded in the
- // daemon binary. The daemon invokes `containerd-shim
- // -containerd-binary ...` with its own os.Executable() path.
- flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)")
- flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file")
-}
-
func main() {
- flag.Parse()
-
- // This is a hack. Exec current process to run standard containerd-shim
- // if runtime root is not `runsc`. We don't need this for shim v2 api.
- if filepath.Base(runtimeRootFlag) != "runsc" {
- if err := executeRuncShim(); err != nil {
- fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
- os.Exit(1)
- }
- }
-
- // Run regular shim if needed.
- if err := executeShim(); err != nil {
- fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
- os.Exit(1)
- }
-}
-
-// executeRuncShim execs current process to a containerd-shim process and
-// retains all flags and envs.
-func executeRuncShim() error {
- c, err := loadConfig(shimConfigFlag)
- if err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("failed to load shim config: %w", err)
- }
- shimPath := c.RuncShim
- if shimPath == "" {
- shimPath, err = exec.LookPath("containerd-shim")
- if err != nil {
- return fmt.Errorf("lookup containerd-shim failed: %w", err)
- }
- }
-
- args := append([]string{shimPath}, os.Args[1:]...)
- if err := syscall.Exec(shimPath, args, os.Environ()); err != nil {
- return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err)
- }
- return nil
-}
-
-func executeShim() error {
- // start handling signals as soon as possible so that things are
- // properly reaped or if runtime exits before we hit the handler.
- signals, err := setupSignals()
- if err != nil {
- return err
- }
- path, err := os.Getwd()
- if err != nil {
- return err
- }
- server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser()))
- if err != nil {
- return fmt.Errorf("failed creating server: %w", err)
- }
- c, err := loadConfig(shimConfigFlag)
- if err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("failed to load shim config: %w", err)
- }
- sv, err := shim.NewService(
- shim.Config{
- Path: path,
- Namespace: namespaceFlag,
- WorkDir: workdirFlag,
- RuntimeRoot: runtimeRootFlag,
- RunscConfig: c.RunscConfig,
- },
- &remoteEventsPublisher{address: addressFlag},
- )
- if err != nil {
- return err
- }
- registerShimService(server, sv)
- if err := serve(server, socketFlag); err != nil {
- return err
- }
- return handleSignals(signals, server, sv)
-}
-
-// serve serves the ttrpc API over a unix socket at the provided path this
-// function does not block.
-func serve(server *ttrpc.Server, path string) error {
- var (
- l net.Listener
- err error
- )
- if path == "" {
- l, err = net.FileListener(os.NewFile(3, "socket"))
- path = "[inherited from parent]"
- } else {
- if len(path) > 106 {
- return fmt.Errorf("%q: unix socket path too long (> 106)", path)
- }
- l, err = net.Listen("unix", "\x00"+path)
- }
- if err != nil {
- return err
- }
- go func() {
- defer l.Close()
- err := server.Serve(context.Background(), l)
- if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
- log.Fatalf("ttrpc server failure: %v", err)
- }
- }()
- return nil
-}
-
-// setupSignals creates a new signal handler for all signals and sets the shim
-// as a sub-reaper so that the container processes are reparented.
-func setupSignals() (chan os.Signal, error) {
- signals := make(chan os.Signal, 32)
- signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE)
- // make sure runc is setup to use the monitor for waiting on processes.
- // TODO(random-liu): Move shim/reaper.go to a separate package.
- runsc.Monitor = reaper.Default
- // Set the shim as the subreaper for all orphaned processes created by
- // the container.
- if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
- return nil, err
- }
- return signals, nil
-}
-
-func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error {
- var (
- termOnce sync.Once
- done = make(chan struct{})
- )
-
- for {
- select {
- case <-done:
- return nil
- case s := <-signals:
- switch s {
- case unix.SIGCHLD:
- if _, err := sys.Reap(false); err != nil {
- log.Printf("reap error: %v", err)
- }
- case unix.SIGTERM, unix.SIGINT:
- go termOnce.Do(func() {
- ctx := context.TODO()
- if err := server.Shutdown(ctx); err != nil {
- log.Printf("failed to shutdown server: %v", err)
- }
- // Ensure our child is dead if any.
- sv.Kill(ctx, &KillRequest{
- Signal: uint32(syscall.SIGKILL),
- All: true,
- })
- sv.Delete(context.Background(), &types.Empty{})
- close(done)
- })
- case unix.SIGPIPE:
- }
- }
- }
-}
-
-type remoteEventsPublisher struct {
- address string
-}
-
-func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error {
- ns, _ := namespaces.Namespace(ctx)
- encoded, err := typeurl.MarshalAny(event)
- if err != nil {
- return err
- }
- data, err := encoded.Marshal()
- if err != nil {
- return err
- }
- cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns)
- cmd.Stdin = bytes.NewReader(data)
- c, err := reaper.Default.Start(cmd)
- if err != nil {
- return err
- }
- status, err := reaper.Default.Wait(cmd, c)
- if err != nil {
- return fmt.Errorf("failed to publish event: %w", err)
- }
- if status != 0 {
- return fmt.Errorf("failed to publish event: status %d", status)
- }
- return nil
+ cli.Main()
}
diff --git a/shim/v2/BUILD b/shim/v2/BUILD
index 8de9ac0ba..b4a107d27 100644
--- a/shim/v2/BUILD
+++ b/shim/v2/BUILD
@@ -4,15 +4,10 @@ package(licenses = ["notice"])
go_binary(
name = "containerd-shim-runsc-v1",
- srcs = [
- "main.go",
- ],
+ srcs = ["main.go"],
static = True,
visibility = [
"//visibility:public",
],
- deps = [
- "//pkg/shim/v2",
- "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
- ],
+ deps = ["//shim/v2/cli"],
)
diff --git a/shim/v2/cli/BUILD b/shim/v2/cli/BUILD
new file mode 100644
index 000000000..6681e0772
--- /dev/null
+++ b/shim/v2/cli/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cli",
+ srcs = ["cli.go"],
+ visibility = [
+ "//:__pkg__",
+ "//shim/v2:__pkg__",
+ ],
+ deps = [
+ "//pkg/shim/v2",
+ "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
+ ],
+)
diff --git a/shim/v2/cli/cli.go b/shim/v2/cli/cli.go
new file mode 100644
index 000000000..3d6644feb
--- /dev/null
+++ b/shim/v2/cli/cli.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cli defines the command line interface for the V2 shim.
+package cli
+
+import (
+ "github.com/containerd/containerd/runtime/v2/shim"
+
+ "gvisor.dev/gvisor/pkg/shim/v2"
+)
+
+// Main is the main entrypoint.
+func Main() {
+ shim.Run("io.containerd.runsc.v1", v2.New)
+}
diff --git a/shim/v2/main.go b/shim/v2/main.go
index 753871eea..3680cdf9c 100644
--- a/shim/v2/main.go
+++ b/shim/v2/main.go
@@ -1,5 +1,4 @@
-// Copyright 2018 The containerd Authors.
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -13,14 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Binary containerd-shim-runsc-v1 is the v2 containerd shim (implementing the formal v1 API).
package main
import (
- "github.com/containerd/containerd/runtime/v2/shim"
-
- "gvisor.dev/gvisor/pkg/shim/v2"
+ "gvisor.dev/gvisor/shim/v2/cli"
)
func main() {
- shim.Run("io.containerd.runsc.v1", v2.New)
+ cli.Main()
}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index 32bf2a992..d3e5efd4f 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -441,9 +441,20 @@ func (FilterOutputDestination) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
- rules := [][]string{
- {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
- {"-P", "OUTPUT", "DROP"},
+ var rules [][]string
+ if ipv6 {
+ rules = [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ // Allow solicited node multicast addresses so we can send neighbor
+ // solicitations.
+ {"-A", "OUTPUT", "-d", "ff02::1:ff00:0/104", "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ } else {
+ rules = [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
}
if err := filterTableRules(ipv6, rules); err != nil {
return err
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index dd9a18339..b98d99fb8 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -577,11 +577,18 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
connCh := make(chan int)
errCh := make(chan error)
go func() {
- connFD, _, err := syscall.Accept(sockfd)
- if err != nil {
- errCh <- err
+ for {
+ connFD, _, err := syscall.Accept(sockfd)
+ if errors.Is(err, syscall.EINTR) {
+ continue
+ }
+ if err != nil {
+ errCh <- err
+ return
+ }
+ connCh <- connFD
+ return
}
- connCh <- connFD
}()
// Wait for accept() to return or for the context to finish.
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 232fdd4d4..9b5994d59 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -57,13 +57,13 @@ def _syscall_test(
platform,
use_tmpfs,
tags,
+ debug,
network = "none",
file_access = "exclusive",
overlay = False,
add_uds_tree = False,
vfs2 = False,
- fuse = False,
- debug = True):
+ fuse = False):
# Prepend "runsc" to non-native platform names.
full_platform = platform if platform == "native" else "runsc_" + platform
@@ -179,6 +179,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + vfs2_tags,
+ debug = debug,
vfs2 = True,
fuse = fuse,
)
@@ -194,6 +195,7 @@ def syscall_test(
use_tmpfs = False,
add_uds_tree = add_uds_tree,
tags = list(tags),
+ debug = debug,
)
for (platform, platform_tags) in platforms.items():
@@ -205,6 +207,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platform_tags + tags,
+ debug = debug,
)
if add_overlay:
@@ -216,6 +219,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
+ debug = debug,
overlay = True,
)
@@ -232,6 +236,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + overlay_vfs2_tags,
+ debug = debug,
overlay = True,
vfs2 = True,
)
@@ -246,6 +251,7 @@ def syscall_test(
network = "host",
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
+ debug = debug,
)
if not use_tmpfs:
@@ -258,6 +264,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + tags,
+ debug = debug,
file_access = "shared",
)
_syscall_test(
@@ -268,6 +275,7 @@ def syscall_test(
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
tags = platforms[default_platform] + vfs2_tags,
+ debug = debug,
file_access = "shared",
vfs2 = True,
)
diff --git a/test/runner/runner.go b/test/runner/runner.go
index 22d535f8d..7ab2c3edf 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -53,6 +53,9 @@ var (
runscPath = flag.String("runsc", "", "path to runsc binary")
addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests")
+ // TODO(gvisor.dev/issue/4572): properly support leak checking for runsc, and
+ // set to true as the default for the test runner.
+ leakCheck = flag.Bool("leak-check", false, "check for reference leaks")
)
// runTestCaseNative runs the test case directly on the host machine.
@@ -174,6 +177,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
if *addUDSTree {
args = append(args, "-fsgofer-host-uds")
}
+ if *leakCheck {
+ args = append(args, "-ref-leak-mode=log-names")
+ }
testLogDir := ""
if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv
index d978baca7..e41441374 100644
--- a/test/runtimes/exclude/java11.csv
+++ b/test/runtimes/exclude/java11.csv
@@ -144,6 +144,7 @@ jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing f
jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr
jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event
jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default'
+jdk/jfr/event/oldobject/TestLargeRootSet.java,,Flaky - `main' threw exception: java.lang.RuntimeException: Could not find root object
jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold
jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base
jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv
index ba993814f..c4e7917ec 100644
--- a/test/runtimes/exclude/nodejs12.4.0.csv
+++ b/test/runtimes/exclude/nodejs12.4.0.csv
@@ -1,31 +1,22 @@
test name,bug id,comment
async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29
-benchmark/test-benchmark-fs.js,,
-benchmark/test-benchmark-napi.js,,
+benchmark/test-benchmark-fs.js,,Broken test
+benchmark/test-benchmark-napi.js,,Broken test
doctool/test-make-doc.js,b/68848110,Expected to fail.
internet/test-dgram-multicast-set-interface-lo.js,b/162798882,
-internet/test-doctool-versions.js,,
-internet/test-uv-threadpool-schedule.js,,
-parallel/test-cluster-dgram-reuse.js,b/64024294,
+internet/test-doctool-versions.js,,Broken test
+internet/test-uv-threadpool-schedule.js,,Broken test
parallel/test-dgram-bind-fd.js,b/132447356,
parallel/test-dgram-socket-buffer-size.js,b/68847921,
parallel/test-dns-channel-timeout.js,b/161893056,
-parallel/test-fs-access.js,,
-parallel/test-fs-watchfile.js,,Flaky - File already exists error
-parallel/test-fs-write-stream.js,b/166819807,Flaky
-parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky
-parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky
-parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
+parallel/test-fs-access.js,,Broken test
+parallel/test-fs-watchfile.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky - VFS1 only
+parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky - VFS1 only
+parallel/test-http-writable-true-after-close.js,b/171301436,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
parallel/test-os.js,b/63997097,
-parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE
-parallel/test-process-uid-gid.js,,
-parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE
-parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE
-parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE
+parallel/test-process-uid-gid.js,,Does not work inside Docker with gid nobody
pseudo-tty/test-assert-colors.js,b/162801321,
pseudo-tty/test-assert-no-color.js,b/162801321,
pseudo-tty/test-assert-position-indicator.js,b/162801321,
@@ -48,11 +39,7 @@ pseudo-tty/test-tty-stdout-resize.js,b/162801321,
pseudo-tty/test-tty-stream-constructors.js,b/162801321,
pseudo-tty/test-tty-window-size.js,b/162801321,
pseudo-tty/test-tty-wrap.js,b/162801321,
-pummel/test-heapdump-http2.js,,Flaky
-pummel/test-net-pingpong.js,,
+pummel/test-net-pingpong.js,,Broken test
pummel/test-vm-memleak.js,b/162799436,
-pummel/test-watch-file.js,,Flaky - Timeout
-sequential/test-child-process-pass-fd.js,b/63926391,Flaky
-sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE
-sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout
-tick-processor/test-tick-processor-builtin.js,,
+pummel/test-watch-file.js,,Flaky - VFS1 only
+tick-processor/test-tick-processor-builtin.js,,Broken test
diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv
index a73f3bcfb..9e1f4c050 100644
--- a/test/runtimes/exclude/php7.3.6.csv
+++ b/test/runtimes/exclude/php7.3.6.csv
@@ -26,12 +26,13 @@ ext/standard/tests/file/php_fd_wrapper_01.phpt,,
ext/standard/tests/file/php_fd_wrapper_02.phpt,,
ext/standard/tests/file/php_fd_wrapper_03.phpt,,
ext/standard/tests/file/php_fd_wrapper_04.phpt,,
-ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,
+ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,VFS1 only failure
ext/standard/tests/file/rename_variation.phpt,b/68717309,
ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341,
ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,
ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
+ext/standard/tests/network/bug20134.phpt,b/171347929,Flaky
ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb
ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky
ext/standard/tests/streams/stream_socket_sendto.phpt,,
diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv
index 8760f8951..911f22855 100644
--- a/test/runtimes/exclude/python3.7.3.csv
+++ b/test/runtimes/exclude/python3.7.3.csv
@@ -18,4 +18,3 @@ test_selectors,b/76116849,OSError not raised with epoll
test_smtplib,b/162980434,unclosed sockets
test_signal,,Flaky - signal: alarm clock
test_socket,b/75983380,
-test_subprocess,b/162980831,
diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go
index e5607ac92..81cb68381 100644
--- a/test/runtimes/proctor/main.go
+++ b/test/runtimes/proctor/main.go
@@ -22,6 +22,7 @@ import (
"log"
"os"
"strings"
+ "syscall"
"gvisor.dev/gvisor/test/runtimes/proctor/lib"
)
@@ -33,6 +34,29 @@ var (
pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
)
+// setNumFilesLimit changes the NOFILE soft rlimit if it is too high.
+func setNumFilesLimit() error {
+ // In docker containers, the default value of the NOFILE limit is
+ // 1048576. A few runtime tests (e.g. python:test_subprocess)
+ // enumerates all possible file descriptors and these tests can fail by
+ // timeout if the NOFILE limit is too high. On gVisor, syscalls are
+ // slower so these tests will need even more time to pass.
+ const nofile = 32768
+ rLimit := syscall.Rlimit{}
+ err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ if err != nil {
+ return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
+ }
+ if rLimit.Cur > nofile {
+ rLimit.Cur = nofile
+ err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ if err != nil {
+ return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
+ }
+ }
+ return nil
+}
+
func main() {
flag.Parse()
@@ -74,6 +98,10 @@ func main() {
tests = strings.Split(*testNames, ",")
}
+ if err := setNumFilesLimit(); err != nil {
+ log.Fatalf("%v", err)
+ }
+
// Run tests.
cmds := tr.TestCmds(tests)
for _, cmd := range cmds {
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 36b7f1b97..c94c1d5bd 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1285,6 +1285,7 @@ cc_binary(
"//test/util:mount_util",
"//test/util:multiprocess_util",
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -2434,6 +2435,7 @@ cc_library(
"@com_google_absl//absl/memory",
gtest,
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:test_util",
],
alwayslink = 1,
@@ -3441,6 +3443,7 @@ cc_binary(
"@com_google_absl//absl/strings",
gtest,
"//test/util:posix_error",
+ "//test/util:save_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -3624,6 +3627,7 @@ cc_binary(
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:timer_util",
],
)
diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc
index 549141cbb..b286e84fe 100644
--- a/test/syscalls/linux/flock.cc
+++ b/test/syscalls/linux/flock.cc
@@ -216,14 +216,29 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
- // Register a signal handler for SIGALRM and set an alarm that will go off
- // while blocking in the subsequent flock() call. This will interrupt flock()
- // and cause it to return EINTR.
+ // Make sure that a blocking flock() call will return EINTR when interrupted
+ // by a signal. Create a timer that will go off while blocking on flock(), and
+ // register the corresponding signal handler.
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(
+ TimerCreate(CLOCK_MONOTONIC, sigevent_t{
+ .sigev_signo = SIGALRM,
+ .sigev_notify = SIGEV_SIGNAL,
+ }));
+
struct sigaction act = {};
act.sa_handler = trivial_handler;
ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds());
- ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds());
+
+ // Now that the signal handler is registered, set the timer. Set an interval
+ // so that it's ok if the timer goes off before we call flock.
+ ASSERT_NO_ERRNO(
+ timer.Set(0, itimerspec{
+ .it_interval = absl::ToTimespec(absl::Milliseconds(10)),
+ .it_value = absl::ToTimespec(absl::Milliseconds(10)),
+ }));
+
ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallFailsWithErrno(EINTR));
+ timer.reset();
// Unlock
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
@@ -258,14 +273,29 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR));
- // Register a signal handler for SIGALRM and set an alarm that will go off
- // while blocking in the subsequent flock() call. This will interrupt flock()
- // and cause it to return EINTR.
+ // Make sure that a blocking flock() call will return EINTR when interrupted
+ // by a signal. Create a timer that will go off while blocking on flock(), and
+ // register the corresponding signal handler.
+ auto timer = ASSERT_NO_ERRNO_AND_VALUE(
+ TimerCreate(CLOCK_MONOTONIC, sigevent_t{
+ .sigev_signo = SIGALRM,
+ .sigev_notify = SIGEV_SIGNAL,
+ }));
+
struct sigaction act = {};
act.sa_handler = trivial_handler;
ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds());
- ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds());
+
+ // Now that the signal handler is registered, set the timer. Set an interval
+ // so that it's ok if the timer goes off before we call flock.
+ ASSERT_NO_ERRNO(
+ timer.Set(0, itimerspec{
+ .it_interval = absl::ToTimespec(absl::Milliseconds(10)),
+ .it_value = absl::ToTimespec(absl::Milliseconds(10)),
+ }));
+
ASSERT_THAT(flock(fd.get(), LOCK_EX), SyscallFailsWithErrno(EINTR));
+ timer.reset();
// Unlock
ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0));
diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc
index ae65d366b..1635c6d0c 100644
--- a/test/syscalls/linux/mknod.cc
+++ b/test/syscalls/linux/mknod.cc
@@ -93,15 +93,15 @@ TEST(MknodTest, MknodOnExistingPathFails) {
}
TEST(MknodTest, UnimplementedTypesReturnError) {
- const std::string path = NewTempAbsPath();
+ // TODO(gvisor.dev/issue/1624): These file types are supported by some
+ // filesystems in VFS2, so this test should be deleted along with VFS1.
+ SKIP_IF(!IsRunningWithVFS1());
- if (IsRunningWithVFS1()) {
- ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0),
- SyscallFailsWithErrno(EOPNOTSUPP));
- }
- // These will fail on linux as well since we don't have CAP_MKNOD.
- ASSERT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM));
- ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM));
+ const std::string path = NewTempAbsPath();
+ EXPECT_THAT(mknod(path.c_str(), S_IFSOCK, 0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM));
}
TEST(MknodTest, Socket) {
@@ -125,6 +125,16 @@ TEST(MknodTest, Socket) {
ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds());
}
+PosixErrorOr<FileDescriptor> OpenRetryEINTR(std::string const& path, int flags,
+ mode_t mode = 0) {
+ while (true) {
+ auto maybe_fd = Open(path, flags, mode);
+ if (maybe_fd.ok() || maybe_fd.error().errno_value() != EINTR) {
+ return maybe_fd;
+ }
+ }
+}
+
TEST(MknodTest, Fifo) {
const std::string fifo = NewTempAbsPath();
ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0),
@@ -139,14 +149,16 @@ TEST(MknodTest, Fifo) {
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
// Write-end of the pipe.
- FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY));
+ FileDescriptor wfd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_WRONLY));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
}
@@ -164,15 +176,16 @@ TEST(MknodTest, FifoOtrunc) {
std::vector<char> buf(512);
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
// Write-end of the pipe.
- FileDescriptor wfd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
}
@@ -192,14 +205,15 @@ TEST(MknodTest, FifoTruncNoOp) {
std::vector<char> buf(512);
// Read-end of the pipe.
ScopedThread t([&fifo, &buf, &msg]() {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY));
EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(msg.length()));
EXPECT_EQ(msg, std::string(buf.data()));
});
- FileDescriptor wfd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC));
+ FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(
+ OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC));
EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL));
EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()),
SyscallSucceedsWithValue(msg.length()));
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index 3aab25b23..d65b7d031 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -34,6 +34,7 @@
#include "test/util/mount_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -131,7 +132,9 @@ TEST(MountTest, UmountDetach) {
ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700",
/* umountflags= */ MNT_DETACH));
const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
- EXPECT_NE(before.st_ino, after.st_ino);
+ EXPECT_FALSE(before.st_dev == after.st_dev && before.st_ino == after.st_ino)
+ << "mount point has device number " << before.st_dev
+ << " and inode number " << before.st_ino << " before and after mount";
// Create files in the new mount.
constexpr char kContents[] = "no no no";
@@ -147,12 +150,14 @@ TEST(MountTest, UmountDetach) {
// Unmount the tmpfs.
mount.Release()();
- // Only check for inode number equality if the directory is not in overlayfs.
- // If xino option is not enabled and if all overlayfs layers do not belong to
- // the same filesystem then "the value of st_ino for directory objects may not
- // be persistent and could change even while the overlay filesystem is
- // mounted." -- Documentation/filesystems/overlayfs.txt
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ //
+ // For overlayfs, if xino option is not enabled and if all overlayfs layers do
+ // not belong to the same filesystem then "the value of st_ino for directory
+ // objects may not be persistent and could change even while the overlay
+ // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(before.st_ino, after2.st_ino);
}
@@ -214,18 +219,23 @@ TEST(MountTest, MountTmpfs) {
const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(s.st_mode, S_IFDIR | 0700);
- EXPECT_NE(s.st_ino, before.st_ino);
+ EXPECT_FALSE(before.st_dev == s.st_dev && before.st_ino == s.st_ino)
+ << "mount point has device number " << before.st_dev
+ << " and inode number " << before.st_ino << " before and after mount";
EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777));
}
// Now that dir is unmounted again, we should have the old inode back.
- // Only check for inode number equality if the directory is not in overlayfs.
- // If xino option is not enabled and if all overlayfs layers do not belong to
- // the same filesystem then "the value of st_ino for directory objects may not
- // be persistent and could change even while the overlay filesystem is
- // mounted." -- Documentation/filesystems/overlayfs.txt
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
+ //
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ //
+ // For overlayfs, if xino option is not enabled and if all overlayfs layers do
+ // not belong to the same filesystem then "the value of st_ino for directory
+ // objects may not be persistent and could change even while the overlay
+ // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) {
const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path()));
EXPECT_EQ(before.st_ino, after.st_ino);
}
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index b558e3a01..a7c46adbf 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -664,6 +664,17 @@ TEST_P(RawPacketTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+TEST_P(RawPacketTest, GetSocketAcceptConn) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index c097c9187..06d9dbf65 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -569,30 +569,38 @@ TEST_P(PipeTest, Streaming) {
// Size() requires 2 syscalls, call it once and remember the value.
const int pipe_size = Size();
+ const size_t streamed_bytes = 4 * pipe_size;
absl::Notification notify;
- ScopedThread t([this, &notify, pipe_size]() {
+ ScopedThread t([&, this]() {
+ std::vector<char> buf(1024);
// Don't start until it's full.
notify.WaitForNotification();
- for (int i = 0; i < pipe_size; i++) {
- int rbuf;
- ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)),
- SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf, i);
+ ssize_t total = 0;
+ while (total < streamed_bytes) {
+ ASSERT_THAT(read(rfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ total += buf.size();
}
});
// Write 4 bytes * pipe_size. It will fill up the pipe once, notify the reader
// to start. Then we write pipe size worth 3 more times to ensure the reader
// can follow along.
+ //
+ // The size of each write (which is determined by buf.size()) must be smaller
+ // than the size of the pipe (which, in the "smallbuffer" configuration, is 1
+ // page) for the check for notify.Notify() below to be correct.
+ std::vector<char> buf(1024);
+ RandomizeBuffer(buf.data(), buf.size());
ssize_t total = 0;
- for (int i = 0; i < pipe_size; i++) {
- ssize_t written = write(wfd_.get(), &i, sizeof(i));
- ASSERT_THAT(written, SyscallSucceedsWithValue(sizeof(i)));
- total += written;
+ while (total < streamed_bytes) {
+ ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ total += buf.size();
// Is the next write about to fill up the buffer? Wake up the reader once.
- if (total < pipe_size && (total + written) >= pipe_size) {
+ if (total < pipe_size && (total + buf.size()) >= pipe_size) {
notify.Notify();
}
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index e8fcc4439..7a0f33dff 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -26,6 +26,7 @@
#include <string.h>
#include <sys/mman.h>
#include <sys/prctl.h>
+#include <sys/ptrace.h>
#include <sys/stat.h>
#include <sys/statfs.h>
#include <sys/utsname.h>
@@ -512,6 +513,414 @@ TEST(ProcSelfAuxv, EntryValues) {
EXPECT_EQ(i, proc_auxv.size());
}
+// Just open and read a part of /proc/self/mem, check that we can read an item.
+TEST(ProcPidMem, Read) {
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ char input[] = "hello-world";
+ char output[sizeof(input)];
+ ASSERT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(input)),
+ SyscallSucceedsWithValue(sizeof(input)));
+ ASSERT_STREQ(input, output);
+}
+
+// Perform read on an unmapped region.
+TEST(ProcPidMem, Unmapped) {
+ // Strategy: map then unmap, so we have a guaranteed unmapped region
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ // Fill it with things
+ memset(mapping.ptr(), 'x', mapping.len());
+ char expected = 'x', output;
+ ASSERT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(mapping.ptr())),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(expected, output);
+
+ // Unmap region again
+ ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds());
+
+ // Now we want EIO error
+ ASSERT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(mapping.ptr())),
+ SyscallFailsWithErrno(EIO));
+}
+
+// Perform read repeatedly to verify offset change.
+TEST(ProcPidMem, RepeatedRead) {
+ auto const num_reads = 3;
+ char expected[] = "01234567890abcdefghijkl";
+ char output[sizeof(expected) / num_reads];
+
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ for (auto i = 0; i < num_reads; i++) {
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[i * sizeof(output)], output, sizeof(output)),
+ 0);
+ }
+}
+
+// Perform seek operations repeatedly.
+TEST(ProcPidMem, RepeatedSeek) {
+ auto const num_reads = 3;
+ char expected[] = "01234567890abcdefghijkl";
+ char output[sizeof(expected) / num_reads];
+
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+ ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ // Read from start
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0);
+ // Skip ahead one read
+ ASSERT_THAT(lseek(memfd.get(), sizeof(output), SEEK_CUR),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected) +
+ sizeof(output) * 2));
+ // Do read again
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[2 * sizeof(output)], output, sizeof(output)), 0);
+ // Skip back three reads
+ ASSERT_THAT(lseek(memfd.get(), -3 * sizeof(output), SEEK_CUR),
+ SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected)));
+ // Do read again
+ ASSERT_THAT(read(memfd.get(), &output, sizeof(output)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0);
+ // Check that SEEK_END does not work
+ ASSERT_THAT(lseek(memfd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL));
+}
+
+// Perform read past an allocated memory region.
+TEST(ProcPidMem, PartialRead) {
+ // Strategy: map large region, then do unmap and remap smaller region
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY));
+
+ Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds());
+ Mapping smaller_mapping = ASSERT_NO_ERRNO_AND_VALUE(
+ Mmap(mapping.ptr(), kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+
+ // Fill it with things
+ memset(smaller_mapping.ptr(), 'x', smaller_mapping.len());
+
+ // Now we want no error
+ char expected[] = {'x'};
+ std::unique_ptr<char[]> output(new char[kPageSize]);
+ off_t read_offset =
+ reinterpret_cast<off_t>(smaller_mapping.ptr()) + kPageSize - 1;
+ ASSERT_THAT(
+ pread(memfd.get(), output.get(), sizeof(output.get()), read_offset),
+ SyscallSucceedsWithValue(sizeof(expected)));
+ // Since output is larger, than expected we have to do manual compare
+ ASSERT_EQ(expected[0], (output).get()[0]);
+}
+
+// Perform read on /proc/[pid]/mem after exit.
+TEST(ProcPidMem, AfterExit) {
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ char expected[] = "hello-world";
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char ok = 1;
+ TEST_CHECK(WriteFd(pfd1[1], &ok, sizeof(ok)) == sizeof(ok));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Wait for child to be alive and well
+ char ok = 0;
+ EXPECT_THAT(ReadFd(pfd1[0], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+
+ // Expect that we can read
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(&expected)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+
+ // Expect that we can't read anymore
+ EXPECT_THAT(pread(memfd.get(), &output, sizeof(output),
+ reinterpret_cast<off_t>(&expected)),
+ SyscallSucceedsWithValue(0));
+}
+
+// Read from /proc/[pid]/mem with different UID/GID and attached state.
+TEST(ProcPidMem, DifferentUserAttached) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_DAC_OVERRIDE)));
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_PTRACE)));
+
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ ScopedThread([&] {
+ // Attach to child subprocess without stopping it
+ EXPECT_THAT(ptrace(PTRACE_SEIZE, child_pid, NULL, NULL), SyscallSucceeds());
+
+ // Keep capabilities after setuid
+ EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+
+ // Only restore CAP_SYS_PTRACE and CAP_DAC_OVERRIDE
+ EXPECT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, true));
+ EXPECT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, true));
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+ char expected[] = "hello-world";
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(target_location)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0)
+ << " status " << status;
+ });
+}
+
+// Attempt to read from /proc/[pid]/mem with different UID/GID.
+TEST(ProcPidMem, DifferentUser) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
+
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ ScopedThread([&] {
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+
+ // Attempt to open /proc/[child_pid]/mem
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ EXPECT_THAT(open(mempath.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+ });
+}
+
+// Perform read on /proc/[pid]/mem with same UID/GID.
+TEST(ProcPidMem, SameUser) {
+ int pfd1[2] = {};
+ int pfd2[2] = {};
+
+ ASSERT_THAT(pipe(pfd1), SyscallSucceeds());
+ ASSERT_THAT(pipe(pfd2), SyscallSucceeds());
+
+ // Create child process
+ pid_t const child_pid = fork();
+ if (child_pid == 0) {
+ // Close reading end of first pipe
+ close(pfd1[0]);
+
+ // Tell parent about location of input
+ char input[] = "hello-world";
+ off_t input_location = reinterpret_cast<off_t>(input);
+ TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) ==
+ sizeof(input_location));
+ TEST_PCHECK(close(pfd1[1]) == 0);
+
+ // Close writing end of second pipe
+ TEST_PCHECK(close(pfd2[1]) == 0);
+
+ // Await parent OK to die
+ char ok = 0;
+ TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok));
+
+ // Close rest pipes
+ TEST_PCHECK(close(pfd2[0]) == 0);
+ _exit(0);
+ }
+ // In parent process.
+ ASSERT_THAT(child_pid, SyscallSucceeds());
+
+ // Close writing end of first pipe
+ EXPECT_THAT(close(pfd1[1]), SyscallSucceeds());
+
+ // Read target location from child
+ off_t target_location;
+ EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)),
+ SyscallSucceedsWithValue(sizeof(target_location)));
+ // Close reading end of first pipe
+ EXPECT_THAT(close(pfd1[0]), SyscallSucceeds());
+
+ // Open /proc/pid/mem fd
+ std::string mempath = absl::StrCat("/proc/", child_pid, "/mem");
+ auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY));
+ char expected[] = "hello-world";
+ char output[sizeof(expected)];
+ EXPECT_THAT(pread(memfd.get(), output, sizeof(output),
+ reinterpret_cast<off_t>(target_location)),
+ SyscallSucceedsWithValue(sizeof(output)));
+ EXPECT_STREQ(expected, output);
+
+ // Tell proc its ok to go
+ EXPECT_THAT(close(pfd2[0]), SyscallSucceeds());
+ char ok = 1;
+ EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)),
+ SyscallSucceedsWithValue(sizeof(ok)));
+ EXPECT_THAT(close(pfd2[1]), SyscallSucceeds());
+
+ // Expect termination
+ int status;
+ ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds());
+}
+
// Just open and read /proc/self/maps, check that we can find [stack]
TEST(ProcSelfMaps, Basic) {
auto proc_self_maps =
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
index 9fb1b3a2c..738923822 100644
--- a/test/syscalls/linux/proc_pid_smaps.cc
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
// amount of whitespace).
if (!entry) {
std::cerr << "smaps line not considered a maps line: "
- << maybe_maps_entry.error_message() << std::endl;
+ << maybe_maps_entry.error().message() << std::endl;
return PosixError(
EINVAL,
absl::StrCat("smaps field line without preceding maps line: ", l));
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 1b9dbc584..bd779da92 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -438,6 +438,19 @@ TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+// Test getsockopt for SO_ACCEPTCONN.
+TEST_F(RawSocketICMPTest, GetSocketAcceptConn) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) {
// We're going to receive both the echo request and reply, but the order is
// indeterminate.
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index e9b131ca9..ed6a1c2aa 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -486,6 +486,62 @@ TEST(SemaphoreTest, SemIpcSet) {
ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES));
}
+TEST(SemaphoreTest, SemCtlIpcStat) {
+ // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false));
+ const uid_t kUid = getuid();
+ const gid_t kGid = getgid();
+ time_t start_time = time(nullptr);
+
+ AutoSem sem(semget(IPC_PRIVATE, 10, 0600 | IPC_CREAT));
+ ASSERT_THAT(sem.get(), SyscallSucceeds());
+
+ struct semid_ds ds;
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds());
+
+ EXPECT_EQ(ds.sem_perm.__key, IPC_PRIVATE);
+ EXPECT_EQ(ds.sem_perm.uid, kUid);
+ EXPECT_EQ(ds.sem_perm.gid, kGid);
+ EXPECT_EQ(ds.sem_perm.cuid, kUid);
+ EXPECT_EQ(ds.sem_perm.cgid, kGid);
+ EXPECT_EQ(ds.sem_perm.mode, 0600);
+ // Last semop time is not set on creation.
+ EXPECT_EQ(ds.sem_otime, 0);
+ EXPECT_GE(ds.sem_ctime, start_time);
+ EXPECT_EQ(ds.sem_nsems, 10);
+
+ // The timestamps only have a resolution of seconds; slow down so we actually
+ // see the timestamps change.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Set semid_ds structure of the set.
+ auto last_ctime = ds.sem_ctime;
+ start_time = time(nullptr);
+ struct semid_ds semid_to_set = {};
+ semid_to_set.sem_perm.uid = kUid;
+ semid_to_set.sem_perm.gid = kGid;
+ semid_to_set.sem_perm.mode = 0666;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds());
+ struct sembuf buf = {};
+ buf.sem_op = 1;
+ ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds());
+
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds());
+ EXPECT_EQ(ds.sem_perm.mode, 0666);
+ EXPECT_GE(ds.sem_otime, start_time);
+ EXPECT_GT(ds.sem_ctime, last_ctime);
+
+ // An invalid semid fails the syscall with errno EINVAL.
+ EXPECT_THAT(semctl(sem.get() + 1, 0, IPC_STAT, &ds),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Make semaphore not readable and check the signal fails.
+ semid_to_set.sem_perm.mode = 0200;
+ ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds());
+ EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds),
+ SyscallFailsWithErrno(EACCES));
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 11fcec443..39a68c5a5 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -350,6 +350,10 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ // TODO(b/157236388): Remove Disable save after bug is fixed. S/R test can
+ // fail because the last socket may not be delivered to the accept queue
+ // by the time connect returns.
+ DisableSave ds;
for (int i = 0; i < kBacklog; i++) {
auto client = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -554,7 +558,11 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
});
}
-TEST_P(SocketInetLoopbackTest, TCPbacklog) {
+// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -567,7 +575,8 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) {
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds());
+ constexpr int kBacklogSize = 2;
+ ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
@@ -1143,6 +1152,9 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ // TODO(b/157236388): Reenable Cooperative S/R once bug is fixed.
+ DisableSave ds;
ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index f4b69c46c..831d96262 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -14,6 +14,7 @@
#include "test/syscalls/linux/socket_ip_tcp_generic.h"
+#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -979,6 +980,56 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) {
EXPECT_EQ(get, kAbove);
}
+#ifdef __linux__
+TEST_P(TCPSocketPairTest, SpliceFromPipe) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ FileDescriptor rfd(fds[0]);
+ FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, sockets->first_fd(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> rbuf(buf.size());
+ ASSERT_THAT(read(sockets->second_fd(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
+TEST_P(TCPSocketPairTest, SpliceToPipe) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ FileDescriptor rfd(fds[0]);
+ FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(sockets->first_fd(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ shutdown(sockets->first_fd(), SHUT_WR);
+ EXPECT_THAT(
+ splice(sockets->second_fd(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> rbuf(buf.size());
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+#endif // __linux__
+
TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
// Discover minimum receive buf by setting a really low value
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 3f2c0fdf2..f69f8f99f 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -472,5 +472,19 @@ TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&sl, &got_linger, length));
}
+// Test getsockopt for SO_ACCEPTCONN on udp socket.
+TEST_P(UDPSocketPairTest, GetSocketAcceptConn) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index a72c76c97..b3f54e7f6 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -28,6 +28,7 @@
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -75,7 +76,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -209,7 +210,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -265,7 +266,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -321,7 +322,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -377,7 +378,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -437,7 +438,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -497,7 +498,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -553,7 +554,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -609,7 +610,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -669,7 +670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -727,7 +728,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -785,7 +786,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -845,7 +846,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -919,7 +920,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -977,7 +978,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -1330,8 +1331,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1409,8 +1410,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
// Check that we received the multicast packet on both sockets.
for (auto& sockets : socket_pairs) {
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1432,8 +1433,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
char recv_buf[sizeof(send_buf)] = {};
for (auto& sockets : socket_pairs) {
- ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf,
- sizeof(recv_buf), 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
}
@@ -1486,7 +1487,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1530,7 +1531,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
// Check that we don't receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
PosixErrorIs(EAGAIN, ::testing::_));
}
@@ -1580,7 +1581,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1627,7 +1628,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1678,7 +1679,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
ASSERT_THAT(
- RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1737,7 +1738,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// of the other sockets to have received it, but we will check that later.
char recv_buf[sizeof(send_buf)] = {};
EXPECT_THAT(
- RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ RecvTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(send_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -1745,9 +1746,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
// Verify that no other messages were received.
for (auto& socket : sockets) {
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- PosixErrorIs(EAGAIN, ::testing::_));
+ EXPECT_THAT(
+ RecvTimeout(socket->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ PosixErrorIs(EAGAIN, ::testing::_));
}
}
@@ -2108,6 +2109,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
constexpr int kMessageSize = 10;
+ // Saving during each iteration of the following loop is too expensive.
+ DisableSave ds;
+
for (int i = 0; i < 100; ++i) {
// Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket
// each time so that a new ephemerial port will be used each time. This
@@ -2120,16 +2124,18 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(sizeof(send_buf)));
}
+ ds.reset();
+
// Check that both receivers got messages. This checks that we are using load
// balancing (REUSEPORT) instead of the most recently bound socket
// (REUSEADDR).
char recv_buf[kMessageSize] = {};
- EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- IsPosixErrorOkAndHolds(kMessageSize));
- EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf),
- 1 /*timeout*/),
- IsPosixErrorOkAndHolds(kMessageSize));
+ EXPECT_THAT(
+ RecvTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
+ EXPECT_THAT(
+ RecvTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kMessageSize));
}
// Test that socket will receive packet info control message.
@@ -2193,8 +2199,8 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
received_msg.msg_control = received_cmsg_buf;
- ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
- SyscallSucceedsWithValue(kDataLength));
+ ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(kDataLength));
cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
ASSERT_NE(cmsg, nullptr);
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
index 49a0f06d9..875016812 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc
@@ -40,17 +40,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) {
/*prefixlen=*/24, &addr, sizeof(addr)));
Cleanup defer_addr_removal = Cleanup(
[loopback_link = std::move(loopback_link), addr = std::move(addr)] {
- if (IsRunningOnGvisor()) {
- // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
- // via netlink is supported.
- EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EOPNOTSUPP, ::testing::_));
- } else {
- EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr,
- sizeof(addr)));
- }
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
});
auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -124,17 +116,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) {
24 /* prefixlen */, &addr, sizeof(addr)));
Cleanup defer_addr_removal = Cleanup(
[loopback_link = std::move(loopback_link), addr = std::move(addr)] {
- if (IsRunningOnGvisor()) {
- // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses
- // via netlink is supported.
- EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EOPNOTSUPP, ::testing::_));
- } else {
- EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
- /*prefixlen=*/24, &addr,
- sizeof(addr)));
- }
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
});
TestAddress broadcast_address("SubnetBroadcastAddress");
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index b3fcf8e7c..e83f0d81f 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arpa/inet.h>
+#include <fcntl.h>
#include <ifaddrs.h>
#include <linux/if.h>
#include <linux/netlink.h>
@@ -335,6 +336,49 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC);
}
+TEST(NetlinkRouteTest, SpliceFromPipe) {
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ FileDescriptor rfd(fds[0]);
+ FileDescriptor wfd(fds[1]);
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = loopback_link.index;
+
+ ASSERT_THAT(write(wfd.get(), &req, sizeof(req)),
+ SyscallSucceedsWithValue(sizeof(req)));
+
+ EXPECT_THAT(splice(rfd.get(), nullptr, fd.get(), nullptr, sizeof(req) + 1, 0),
+ SyscallSucceedsWithValue(sizeof(req)));
+ close(wfd.release());
+ EXPECT_THAT(splice(rfd.get(), nullptr, fd.get(), nullptr, sizeof(req) + 1, 0),
+ SyscallSucceedsWithValue(0));
+
+ bool found = false;
+ ASSERT_NO_ERRNO(NetlinkResponse(
+ fd,
+ [&](const struct nlmsghdr* hdr) {
+ CheckLinkMsg(hdr, loopback_link);
+ found = true;
+ },
+ false));
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
+}
+
TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
@@ -467,53 +511,42 @@ TEST(NetlinkRouteTest, LookupAll) {
ASSERT_GT(count, 0);
}
-TEST(NetlinkRouteTest, AddAddr) {
+TEST(NetlinkRouteTest, AddAndRemoveAddr) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ // Don't do cooperative save/restore because netstack state is not restored.
+ // TODO(gvisor.dev/issue/4595): enable cooperative save tests.
+ const DisableSave ds;
Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
-
- struct request {
- struct nlmsghdr hdr;
- struct ifaddrmsg ifa;
- struct rtattr rtattr;
- struct in_addr addr;
- char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
- };
-
- struct request req = {};
- req.hdr.nlmsg_type = RTM_NEWADDR;
- req.hdr.nlmsg_seq = kSeq;
- req.ifa.ifa_family = AF_INET;
- req.ifa.ifa_prefixlen = 24;
- req.ifa.ifa_flags = 0;
- req.ifa.ifa_scope = 0;
- req.ifa.ifa_index = loopback_link.index;
- req.rtattr.rta_type = IFA_LOCAL;
- req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
- inet_pton(AF_INET, "10.0.0.1", &req.addr);
- req.hdr.nlmsg_len =
- NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len);
+ struct in_addr addr;
+ ASSERT_EQ(inet_pton(AF_INET, "10.0.0.1", &addr), 1);
// Create should succeed, as no such address in kernel.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK;
- EXPECT_NO_ERRNO(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+ ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)));
+
+ Cleanup defer_addr_removal = Cleanup(
+ [loopback_link = std::move(loopback_link), addr = std::move(addr)] {
+ // First delete should succeed, as address exists.
+ EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+
+ // Second delete should fail, as address no longer exists.
+ EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EINVAL, ::testing::_));
+ });
// Replace an existing address should succeed.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
- req.hdr.nlmsg_seq++;
- EXPECT_NO_ERRNO(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+ ASSERT_NO_ERRNO(LinkReplaceLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)));
// Create exclusive should fail, as we created the address above.
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
- req.hdr.nlmsg_seq++;
- EXPECT_THAT(
- NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len),
- PosixErrorIs(EEXIST, ::testing::_));
+ EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET,
+ /*prefixlen=*/24, &addr, sizeof(addr)),
+ PosixErrorIs(EEXIST, ::testing::_));
}
// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request.
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
index 7a0bad4cb..46f749c7c 100644
--- a/test/syscalls/linux/socket_netlink_route_util.cc
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -29,6 +29,8 @@ constexpr uint32_t kSeq = 12345;
// Types of address modifications that may be performed on an interface.
enum class LinkAddrModification {
kAdd,
+ kAddExclusive,
+ kReplace,
kDelete,
};
@@ -40,6 +42,14 @@ PosixError PopulateNlmsghdr(LinkAddrModification modification,
hdr->nlmsg_type = RTM_NEWADDR;
hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
return NoError();
+ case LinkAddrModification::kAddExclusive:
+ hdr->nlmsg_type = RTM_NEWADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_EXCL | NLM_F_ACK;
+ return NoError();
+ case LinkAddrModification::kReplace:
+ hdr->nlmsg_type = RTM_NEWADDR;
+ hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
+ return NoError();
case LinkAddrModification::kDelete:
hdr->nlmsg_type = RTM_DELADDR;
hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
@@ -144,6 +154,18 @@ PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
LinkAddrModification::kAdd);
}
+PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kAddExclusive);
+}
+
+PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
+ LinkAddrModification::kReplace);
+}
+
PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen) {
return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen,
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
index e5badca70..eaa91ad79 100644
--- a/test/syscalls/linux/socket_netlink_route_util.h
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -39,10 +39,19 @@ PosixErrorOr<std::vector<Link>> DumpLinks();
// Returns the loopback link on the system. ENOENT if not found.
PosixErrorOr<Link> LoopbackLink();
-// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+// LinkAddLocalAddr adds a new IFA_LOCAL address to the interface.
PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen);
+// LinkAddExclusiveLocalAddr adds a new IFA_LOCAL address with NLM_F_EXCL flag
+// to the interface.
+PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkReplaceLocalAddr replaces an IFA_LOCAL address on the interface.
+PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface.
PosixError LinkDelLocalAddr(int index, int family, int prefixlen,
const void* addr, int addrlen);
diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc
index 952eecfe8..bdebea321 100644
--- a/test/syscalls/linux/socket_netlink_util.cc
+++ b/test/syscalls/linux/socket_netlink_util.cc
@@ -67,10 +67,21 @@ PosixError NetlinkRequestResponse(
RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
+ return NetlinkResponse(fd, fn, expect_nlmsgerr);
+}
+
+PosixError NetlinkResponse(
+ const FileDescriptor& fd,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn,
+ bool expect_nlmsgerr) {
constexpr size_t kBufferSize = 4096;
std::vector<char> buf(kBufferSize);
+ struct iovec iov = {};
iov.iov_base = buf.data();
iov.iov_len = buf.size();
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
// If NLM_F_MULTI is set, response is a series of messages that ends with a
// NLMSG_DONE message.
diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h
index e13ead406..f97276d44 100644
--- a/test/syscalls/linux/socket_netlink_util.h
+++ b/test/syscalls/linux/socket_netlink_util.h
@@ -41,6 +41,14 @@ PosixError NetlinkRequestResponse(
const std::function<void(const struct nlmsghdr* hdr)>& fn,
bool expect_nlmsgerr);
+// Call fn on all response netlink messages.
+//
+// To be used on requests with NLM_F_MULTI reponses.
+PosixError NetlinkResponse(
+ const FileDescriptor& fd,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn,
+ bool expect_nlmsgerr);
+
// Send the passed request and call fn on all response netlink messages.
//
// To be used on requests without NLM_F_MULTI reponses.
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index e11792309..a760581b5 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -753,8 +753,7 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
return ret;
}
-PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
- int timeout) {
+PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout) {
fd_set rfd;
struct timeval to = {.tv_sec = timeout, .tv_usec = 0};
FD_ZERO(&rfd);
@@ -767,6 +766,19 @@ PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
return ret;
}
+PosixErrorOr<int> RecvMsgTimeout(int sock, struct msghdr* msg, int timeout) {
+ fd_set rfd;
+ struct timeval to = {.tv_sec = timeout, .tv_usec = 0};
+ FD_ZERO(&rfd);
+ FD_SET(sock, &rfd);
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to));
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ ret = RetryEINTR(recvmsg)(sock, msg, MSG_DONTWAIT));
+ return ret;
+}
+
void RecvNoData(int sock) {
char data = 0;
struct iovec iov;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 468bc96e0..5e205339f 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -467,9 +467,12 @@ PosixError FreeAvailablePort(int port);
// SendMsg converts a buffer to an iovec and adds it to msg before sending it.
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size);
-// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock.
-PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size,
- int timeout);
+// RecvTimeout calls select on sock with timeout and then calls recv on sock.
+PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout);
+
+// RecvMsgTimeout calls select on sock with timeout and then calls recvmsg on
+// sock.
+PosixErrorOr<int> RecvMsgTimeout(int sock, msghdr* msg, int timeout);
// RecvNoData checks that no data is receivable on sock.
void RecvNoData(int sock);
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 1edcb15a7..ad9c4bf37 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -121,6 +121,19 @@ TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) {
EXPECT_EQ(0, memcmp(&got_linger, &sl, length));
}
+TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 92260b1e1..6e7142a42 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -31,6 +31,7 @@
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/save_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -328,7 +329,10 @@ TEST_F(StatTest, LeadingDoubleSlash) {
ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st),
SyscallSucceeds());
EXPECT_EQ(st.st_dev, double_slash_st.st_dev);
- EXPECT_EQ(st.st_ino, double_slash_st.st_ino);
+ // Inode numbers for gofer-accessed files may change across save/restore.
+ if (!IsRunningWithSaveRestore()) {
+ EXPECT_EQ(st.st_ino, double_slash_st.st_ino);
+ }
}
// Test that a rename doesn't change the underlying file.
@@ -346,8 +350,14 @@ TEST_F(StatTest, StatDoesntChangeAfterRename) {
EXPECT_EQ(st_old.st_nlink, st_new.st_nlink);
EXPECT_EQ(st_old.st_dev, st_new.st_dev);
+ // Inode numbers for gofer-accessed files on which no reference is held may
+ // change across save/restore because the information that the gofer client
+ // uses to track file identity (9P QID path) is inconsistent between gofer
+ // processes, which are restarted across save/restore.
+ //
// Overlay filesystems may synthesize directory inode numbers on the fly.
- if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) {
+ if (!IsRunningWithSaveRestore() &&
+ !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) {
EXPECT_EQ(st_old.st_ino, st_new.st_ino);
}
EXPECT_EQ(st_old.st_mode, st_new.st_mode);
@@ -541,6 +551,26 @@ TEST_F(StatTest, LstatELOOPPath) {
ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP));
}
+TEST(SimpleStatTest, DifferentFilesHaveDifferentDeviceInodeNumberPairs) {
+ // TODO(gvisor.dev/issue/1624): This test case fails in VFS1 save/restore
+ // tests because VFS1 gofer inode number assignment restarts after
+ // save/restore, such that the inodes for file1 and file2 (which are
+ // unreferenced and therefore not retained in sentry checkpoints before the
+ // calls to lstat()) are assigned the same inode number.
+ SKIP_IF(IsRunningWithVFS1() && IsRunningWithSaveRestore());
+
+ TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+
+ MaybeSave();
+ struct stat st1 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file1.path()));
+ MaybeSave();
+ struct stat st2 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file2.path()));
+ EXPECT_FALSE(st1.st_dev == st2.st_dev && st1.st_ino == st2.st_ino)
+ << "both files have device number " << st1.st_dev << " and inode number "
+ << st1.st_ino;
+}
+
// Ensure that inode allocation for anonymous devices work correctly across
// save/restore. In particular, inode numbers should be unique across S/R.
TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) {
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 9f522f833..ebd873068 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -1725,6 +1725,63 @@ TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) {
ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout));
}
+// Tests that SO_ACCEPTCONN returns non zero value for listening sockets.
+TEST_P(TcpSocketTest, GetSocketAcceptConnListener) {
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(listener_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 1);
+}
+
+// Tests that SO_ACCEPTCONN returns zero value for not listening sockets.
+TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) {
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+
+ ASSERT_THAT(getsockopt(t_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) {
+ // TODO(b/171345701): Fix the TCP state for listening socket on shutdown.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds());
+
+ int got = -1;
+ socklen_t length = sizeof(got);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 1);
+
+ EXPECT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds());
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length),
+ SyscallSucceeds());
+ ASSERT_EQ(length, sizeof(got));
+ EXPECT_EQ(got, 0);
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index 4b3c44527..cac94d9e1 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -33,6 +33,7 @@
#include "test/util/signal_util.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/timer_util.h"
ABSL_FLAG(bool, timers_test_sleep, false,
"If true, sleep forever instead of running tests.");
@@ -215,99 +216,6 @@ TEST(TimerTest, ProcessKilledOnCPUHardLimit) {
EXPECT_GE(cpu, kHardLimit);
}
-// RAII type for a kernel "POSIX" interval timer. (The kernel provides system
-// calls such as timer_create that behave very similarly, but not identically,
-// to those described by timer_create(2); in particular, the kernel does not
-// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on
-// these kernel interval timers.)
-//
-// Compare implementation to FileDescriptor.
-class IntervalTimer {
- public:
- IntervalTimer() = default;
-
- explicit IntervalTimer(int id) { set_id(id); }
-
- IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {}
-
- IntervalTimer& operator=(IntervalTimer&& orig) {
- if (this == &orig) return *this;
- reset(orig.release());
- return *this;
- }
-
- IntervalTimer(const IntervalTimer& other) = delete;
- IntervalTimer& operator=(const IntervalTimer& other) = delete;
-
- ~IntervalTimer() { reset(); }
-
- int get() const { return id_; }
-
- int release() {
- int const id = id_;
- id_ = -1;
- return id;
- }
-
- void reset() { reset(-1); }
-
- void reset(int id) {
- if (id_ >= 0) {
- TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0);
- MaybeSave();
- }
- set_id(id);
- }
-
- PosixErrorOr<struct itimerspec> Set(
- int flags, const struct itimerspec& new_value) const {
- struct itimerspec old_value = {};
- if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) {
- return PosixError(errno, "timer_settime");
- }
- MaybeSave();
- return old_value;
- }
-
- PosixErrorOr<struct itimerspec> Get() const {
- struct itimerspec curr_value = {};
- if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) {
- return PosixError(errno, "timer_gettime");
- }
- MaybeSave();
- return curr_value;
- }
-
- PosixErrorOr<int> Overruns() const {
- int rv = syscall(SYS_timer_getoverrun, id_);
- if (rv < 0) {
- return PosixError(errno, "timer_getoverrun");
- }
- MaybeSave();
- return rv;
- }
-
- private:
- void set_id(int id) { id_ = std::max(id, -1); }
-
- // Kernel timer_t is int; glibc timer_t is void*.
- int id_ = -1;
-};
-
-PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
- const struct sigevent& sev) {
- int timerid;
- int ret = syscall(SYS_timer_create, clockid, &sev, &timerid);
- if (ret < 0) {
- return PosixError(errno, "timer_create");
- }
- if (ret > 0) {
- return PosixError(EINVAL, "timer_create should never return positive");
- }
- MaybeSave();
- return IntervalTimer(timerid);
-}
-
// See timerfd.cc:TimerSlack() for rationale.
constexpr absl::Duration kTimerSlack = absl::Milliseconds(500);
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 1a7673317..bc5bd9218 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -679,6 +679,43 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
SyscallSucceedsWithValue(sizeof(buf)));
}
+TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ // Close the socket to release the port so that we get an ICMP error.
+ ASSERT_THAT(close(bind_.release()), SyscallSucceeds());
+
+ // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an
+ // UDP socket. There is no easy way to ensure that the UDP port is not bound
+ // by another conncurrently running test. *This is potentially flaky*.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[512];
+ EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ constexpr int kTimeout = 1000;
+ // Poll to make sure we get the ICMP error back before issuing more writes.
+ struct pollfd pfd = {sock_.get(), POLLERR, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+
+ // Next write should fail with ECONNREFUSED due to the ICMP error generated in
+ // response to the previous write.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ // The next write should succeed again since the last write call would have
+ // retrieved and cleared the socket error.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), SyscallSucceeds());
+
+ // Poll to make sure we get the ICMP error back before issuing more writes.
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+
+ // Next write should fail with ECONNREFUSED due to the ICMP error generated in
+ // response to the previous write.
+ ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
// TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
SKIP_IF(IsRunningWithHostinet());
@@ -838,7 +875,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
// Receive the data. It works because it was sent before the connect.
char received[sizeof(buf)];
EXPECT_THAT(
- RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/),
+ RecvTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/),
IsPosixErrorOkAndHolds(sizeof(received)));
EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
@@ -928,9 +965,8 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
SyscallSucceedsWithValue(1));
// We should get the data even though read has been shutdown.
- EXPECT_THAT(
- RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/),
- IsPosixErrorOkAndHolds(2));
+ EXPECT_THAT(RecvTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/),
+ IsPosixErrorOkAndHolds(2));
// Because we read less than the entire packet length, since it's a packet
// based socket any subsequent reads should return EWOULDBLOCK.
@@ -1698,8 +1734,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
}
@@ -1714,8 +1750,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
SyscallSucceedsWithValue(buf.size()));
std::vector<char> received(buf.size());
- ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ ASSERT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
}
}
@@ -1785,8 +1821,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) {
for (int i = 0; i < sent - 1; i++) {
// Receive the data.
std::vector<char> received(buf.size());
- EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(),
- 1 /*timeout*/),
+ EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(),
+ 1 /*timeout*/),
IsPosixErrorOkAndHolds(received.size()));
EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0);
}
diff --git a/test/util/BUILD b/test/util/BUILD
index 26c2b6a2f..1b028a477 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -155,6 +155,10 @@ cc_library(
],
hdrs = ["save_util.h"],
defines = select_system(),
+ deps = [
+ ":logging",
+ "@com_google_absl//absl/types:optional",
+ ],
)
cc_library(
diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc
index cebf7e0ac..deed0c05b 100644
--- a/test/util/posix_error.cc
+++ b/test/util/posix_error.cc
@@ -87,7 +87,7 @@ bool PosixErrorIsMatcherCommonImpl::MatchAndExplain(
return false;
}
- if (!message_matcher_.Matches(error.error_message())) {
+ if (!message_matcher_.Matches(error.message())) {
return false;
}
diff --git a/test/util/posix_error.h b/test/util/posix_error.h
index ad666bce0..b634a7f78 100644
--- a/test/util/posix_error.h
+++ b/test/util/posix_error.h
@@ -26,11 +26,6 @@
namespace gvisor {
namespace testing {
-class PosixErrorIsMatcherCommonImpl;
-
-template <typename T>
-class PosixErrorOr;
-
class ABSL_MUST_USE_RESULT PosixError {
public:
PosixError() {}
@@ -49,7 +44,8 @@ class ABSL_MUST_USE_RESULT PosixError {
// PosixErrorOr.
const PosixError& error() const { return *this; }
- std::string error_message() const { return msg_; }
+ int errno_value() const { return errno_; }
+ std::string message() const { return msg_; }
// ToString produces a full string representation of this posix error
// including the printable representation of the errno and the error message.
@@ -61,14 +57,8 @@ class ABSL_MUST_USE_RESULT PosixError {
void IgnoreError() const {}
private:
- int errno_value() const { return errno_; }
int errno_ = 0;
std::string msg_;
-
- friend class PosixErrorIsMatcherCommonImpl;
-
- template <typename T>
- friend class PosixErrorOr;
};
template <typename T>
@@ -94,15 +84,12 @@ class ABSL_MUST_USE_RESULT PosixErrorOr {
template <typename U>
PosixErrorOr& operator=(PosixErrorOr<U> other);
- // Return a reference to the error or NoError().
- PosixError error() const;
-
- // Returns this->error().error_message();
- std::string error_message() const;
-
// Returns true if this PosixErrorOr contains some T.
bool ok() const;
+ // Return a copy of the contained PosixError or NoError().
+ PosixError error() const;
+
// Returns a reference to our current value, or CHECK-fails if !this->ok().
const T& ValueOrDie() const&;
T& ValueOrDie() &;
@@ -115,7 +102,6 @@ class ABSL_MUST_USE_RESULT PosixErrorOr {
void IgnoreError() const {}
private:
- int errno_value() const;
absl::variant<T, PosixError> value_;
friend class PosixErrorIsMatcherCommonImpl;
@@ -171,16 +157,6 @@ PosixError PosixErrorOr<T>::error() const {
}
template <typename T>
-int PosixErrorOr<T>::errno_value() const {
- return error().errno_value();
-}
-
-template <typename T>
-std::string PosixErrorOr<T>::error_message() const {
- return error().error_message();
-}
-
-template <typename T>
bool PosixErrorOr<T>::ok() const {
return absl::holds_alternative<T>(value_);
}
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
index 384d626f0..59d47e06e 100644
--- a/test/util/save_util.cc
+++ b/test/util/save_util.cc
@@ -21,35 +21,47 @@
#include <atomic>
#include <cerrno>
-#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST"
+#include "absl/types/optional.h"
namespace gvisor {
namespace testing {
namespace {
-enum class CooperativeSaveMode {
- kUnknown = 0, // cooperative_save_mode is statically-initialized to 0
- kAvailable,
- kNotAvailable,
-};
-
-std::atomic<CooperativeSaveMode> cooperative_save_mode;
-
-bool CooperativeSaveEnabled() {
- auto mode = cooperative_save_mode.load();
- if (mode == CooperativeSaveMode::kUnknown) {
- mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr)
- ? CooperativeSaveMode::kAvailable
- : CooperativeSaveMode::kNotAvailable;
- cooperative_save_mode.store(mode);
+std::atomic<absl::optional<bool>> cooperative_save_present;
+std::atomic<absl::optional<bool>> random_save_present;
+
+bool CooperativeSavePresent() {
+ auto present = cooperative_save_present.load();
+ if (!present.has_value()) {
+ present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != nullptr;
+ cooperative_save_present.store(present);
+ }
+ return present.value();
+}
+
+bool RandomSavePresent() {
+ auto present = random_save_present.load();
+ if (!present.has_value()) {
+ present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr;
+ random_save_present.store(present);
}
- return mode == CooperativeSaveMode::kAvailable;
+ return present.value();
}
std::atomic<int> save_disable;
} // namespace
+bool IsRunningWithSaveRestore() {
+ return CooperativeSavePresent() || RandomSavePresent();
+}
+
+void MaybeSave() {
+ if (CooperativeSavePresent() && save_disable.load() == 0) {
+ internal::DoCooperativeSave();
+ }
+}
+
DisableSave::DisableSave() { save_disable++; }
DisableSave::~DisableSave() { reset(); }
@@ -61,11 +73,5 @@ void DisableSave::reset() {
}
}
-namespace internal {
-bool ShouldSave() {
- return CooperativeSaveEnabled() && (save_disable.load() == 0);
-}
-} // namespace internal
-
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util.h b/test/util/save_util.h
index bddad6120..e7218ae88 100644
--- a/test/util/save_util.h
+++ b/test/util/save_util.h
@@ -17,9 +17,17 @@
namespace gvisor {
namespace testing {
-// Disable save prevents saving while the given function executes.
+
+// Returns true if the environment in which the calling process is executing
+// allows the test to be checkpointed and restored during execution.
+bool IsRunningWithSaveRestore();
+
+// May perform a co-operative save cycle.
//
-// This lasts the duration of the object, unless reset is called.
+// errno is guaranteed to be preserved.
+void MaybeSave();
+
+// Causes MaybeSave to become a no-op until destroyed or reset.
class DisableSave {
public:
DisableSave();
@@ -37,13 +45,13 @@ class DisableSave {
bool reset_ = false;
};
-// May perform a co-operative save cycle.
+namespace internal {
+
+// Causes a co-operative save cycle to occur.
//
// errno is guaranteed to be preserved.
-void MaybeSave();
+void DoCooperativeSave();
-namespace internal {
-bool ShouldSave();
} // namespace internal
} // namespace testing
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
index d0aea8e6a..57431b3ea 100644
--- a/test/util/save_util_linux.cc
+++ b/test/util/save_util_linux.cc
@@ -30,20 +30,20 @@
namespace gvisor {
namespace testing {
-
-void MaybeSave() {
- if (internal::ShouldSave()) {
- int orig_errno = errno;
- // We use it to trigger saving the sentry state
- // when this syscall is called.
- // Notice: this needs to be a valid syscall
- // that is not used in any of the syscall tests.
- syscall(SYS_TRIGGER_SAVE, nullptr, 0);
- errno = orig_errno;
- }
+namespace internal {
+
+void DoCooperativeSave() {
+ int orig_errno = errno;
+ // We use it to trigger saving the sentry state
+ // when this syscall is called.
+ // Notice: this needs to be a valid syscall
+ // that is not used in any of the syscall tests.
+ syscall(SYS_TRIGGER_SAVE, nullptr, 0);
+ errno = orig_errno;
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
-#endif
+#endif // __linux__
diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc
index 931af2c29..7749ded76 100644
--- a/test/util/save_util_other.cc
+++ b/test/util/save_util_other.cc
@@ -14,13 +14,17 @@
#ifndef __linux__
+#include "test/util/logging.h"
+
namespace gvisor {
namespace testing {
+namespace internal {
-void MaybeSave() {
- // Saving is never available in a non-linux environment.
+void DoCooperativeSave() {
+ TEST_CHECK_MSG(false, "DoCooperativeSave not implemented");
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
diff --git a/test/util/timer_util.cc b/test/util/timer_util.cc
index 43a26b0d3..75cfc4f40 100644
--- a/test/util/timer_util.cc
+++ b/test/util/timer_util.cc
@@ -23,5 +23,23 @@ absl::Time Now(clockid_t id) {
return absl::TimeFromTimespec(now);
}
+#ifdef __linux__
+
+PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
+ const struct sigevent& sev) {
+ int timerid;
+ int ret = syscall(SYS_timer_create, clockid, &sev, &timerid);
+ if (ret < 0) {
+ return PosixError(errno, "timer_create");
+ }
+ if (ret > 0) {
+ return PosixError(EINVAL, "timer_create should never return positive");
+ }
+ MaybeSave();
+ return IntervalTimer(timerid);
+}
+
+#endif // __linux__
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/timer_util.h b/test/util/timer_util.h
index 31aea4fc6..926e6632f 100644
--- a/test/util/timer_util.h
+++ b/test/util/timer_util.h
@@ -16,6 +16,9 @@
#define GVISOR_TEST_UTIL_TIMER_UTIL_H_
#include <errno.h>
+#ifdef __linux__
+#include <sys/syscall.h>
+#endif
#include <sys/time.h>
#include <functional>
@@ -30,6 +33,9 @@
namespace gvisor {
namespace testing {
+// Returns the current time.
+absl::Time Now(clockid_t id);
+
// MonotonicTimer is a simple timer that uses a monotonic clock.
class MonotonicTimer {
public:
@@ -65,8 +71,92 @@ inline PosixErrorOr<Cleanup> ScopedItimer(int which,
}));
}
-// Returns the current time.
-absl::Time Now(clockid_t id);
+#ifdef __linux__
+
+// RAII type for a kernel "POSIX" interval timer. (The kernel provides system
+// calls such as timer_create that behave very similarly, but not identically,
+// to those described by timer_create(2); in particular, the kernel does not
+// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on
+// these kernel interval timers.)
+//
+// Compare implementation to FileDescriptor.
+class IntervalTimer {
+ public:
+ IntervalTimer() = default;
+
+ explicit IntervalTimer(int id) { set_id(id); }
+
+ IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {}
+
+ IntervalTimer& operator=(IntervalTimer&& orig) {
+ if (this == &orig) return *this;
+ reset(orig.release());
+ return *this;
+ }
+
+ IntervalTimer(const IntervalTimer& other) = delete;
+ IntervalTimer& operator=(const IntervalTimer& other) = delete;
+
+ ~IntervalTimer() { reset(); }
+
+ int get() const { return id_; }
+
+ int release() {
+ int const id = id_;
+ id_ = -1;
+ return id;
+ }
+
+ void reset() { reset(-1); }
+
+ void reset(int id) {
+ if (id_ >= 0) {
+ TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0);
+ MaybeSave();
+ }
+ set_id(id);
+ }
+
+ PosixErrorOr<struct itimerspec> Set(
+ int flags, const struct itimerspec& new_value) const {
+ struct itimerspec old_value = {};
+ if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) {
+ return PosixError(errno, "timer_settime");
+ }
+ MaybeSave();
+ return old_value;
+ }
+
+ PosixErrorOr<struct itimerspec> Get() const {
+ struct itimerspec curr_value = {};
+ if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) {
+ return PosixError(errno, "timer_gettime");
+ }
+ MaybeSave();
+ return curr_value;
+ }
+
+ PosixErrorOr<int> Overruns() const {
+ int rv = syscall(SYS_timer_getoverrun, id_);
+ if (rv < 0) {
+ return PosixError(errno, "timer_getoverrun");
+ }
+ MaybeSave();
+ return rv;
+ }
+
+ private:
+ void set_id(int id) { id_ = std::max(id, -1); }
+
+ // Kernel timer_t is int; glibc timer_t is void*.
+ int id_ = -1;
+};
+
+// A wrapper around timer_create(2).
+PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
+ const struct sigevent& sev);
+
+#endif // __linux__
} // namespace testing
} // namespace gvisor
diff --git a/tools/bazel.mk b/tools/bazel.mk
index 25575c02c..88431ce66 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -14,12 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# Make hacks.
+EMPTY :=
+SPACE := $(EMPTY) $(EMPTY)
+
# See base Makefile.
SHELL=/bin/bash -o pipefail
BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
xargs -n 1 basename 2>/dev/null)
-BUILD_ROOT := $(CURDIR)/bazel-bin/
+BUILD_ROOTS := bazel-bin/ bazel-out/
# Bazel container configuration (see below).
USER ?= gvisor
@@ -152,10 +156,12 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai
build_paths = $(build_cmd) 2>&1 \
| tee /proc/self/fd/2 \
- | grep " bazel-bin/" \
+ | grep -A1 -E '^Target' \
+ | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \
| sed "s/ /\n/g" \
| strings -n 10 \
| awk '{$$1=$$1};1' \
+ | xargs -n 1 -I {} readlink -f "{}" \
| xargs -n 1 -I {} sh -c "$(1)"
build: bazel-server
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
index 8d4356119..d043caf06 100644
--- a/tools/bazeldefs/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -26,43 +26,6 @@ rbe_platform(
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50"
- }
- properties: {
- name: "dockerAddCapabilities"
- value: "SYS_ADMIN"
- }
- properties: {
- name: "dockerPrivileged"
- value: "true"
- }
- """,
-)
-
-rbe_toolchain(
- name = "cc-toolchain-clang-x86_64-default",
- exec_compatible_with = [],
- tags = [
- "manual",
- ],
- target_compatible_with = [],
- toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8",
- toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
-)
-
-# Updated versions of the above, compatible with bazel3.
-rbe_platform(
- name = "rbe_ubuntu1604_bazel3",
- constraint_values = [
- "@bazel_tools//platforms:x86_64",
- "@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains_bazel3//constraints:xenial",
- "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan",
- ],
- remote_execution_properties = """
- properties: {
- name: "container-image"
value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272"
}
properties: {
@@ -77,13 +40,13 @@ rbe_platform(
)
rbe_toolchain(
- name = "cc-toolchain-clang-x86_64-default_bazel3",
+ name = "cc-toolchain-clang-x86_64-default",
exec_compatible_with = [],
tags = [
"manual",
],
target_compatible_with = [],
- toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8",
+ toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8",
toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
)
diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl
new file mode 100644
index 000000000..7f41a0142
--- /dev/null
+++ b/tools/bazeldefs/cc.bzl
@@ -0,0 +1,43 @@
+"""C++ rules."""
+
+load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
+load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
+
+cc_library = _cc_library
+cc_flags_supplier = _cc_flags_supplier
+cc_proto_library = _cc_proto_library
+cc_test = _cc_test
+cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
+gtest = "@com_google_googletest//:gtest"
+gbenchmark = "@com_google_benchmark//:benchmark"
+grpcpp = "@com_github_grpc_grpc//:grpc++"
+vdso_linker_option = "-fuse-ld=gold "
+
+def cc_grpc_library(name, **kwargs):
+ _cc_grpc_library(name = name, grpc_only = True, **kwargs)
+
+def cc_binary(name, static = False, **kwargs):
+ """Run cc_binary.
+
+ Args:
+ name: name of the target.
+ static: make a static binary if True
+ **kwargs: the rest of the args.
+ """
+ if static:
+ # How to statically link a c++ program that uses threads, like for gRPC:
+ # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
+ if "linkopts" not in kwargs:
+ kwargs["linkopts"] = []
+ kwargs["linkopts"] += [
+ "-static",
+ "-lstdc++",
+ "-Wl,--whole-archive",
+ "-lpthread",
+ "-Wl,--no-whole-archive",
+ ]
+ _cc_binary(
+ name = name,
+ **kwargs
+ )
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index cf5b1dc0d..ba186aace 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -1,35 +1,13 @@
-"""Bazel implementations of standard rules."""
+"""Meta and miscellaneous rules."""
-load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle")
load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test")
load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library")
-load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
-load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test")
-load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
-load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
-load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
-load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
build_test = _build_test
bzl_library = _bzl_library
-cc_library = _cc_library
-cc_flags_supplier = _cc_flags_supplier
-cc_proto_library = _cc_proto_library
-cc_test = _cc_test
-cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
-gazelle = _gazelle
-go_embed_data = _go_embed_data
-go_path = _go_path
-gtest = "@com_google_googletest//:gtest"
-grpcpp = "@com_github_grpc_grpc//:grpc++"
-gbenchmark = "@com_google_benchmark//:benchmark"
loopback = "//tools/bazeldefs:loopback"
-pkg_deb = _pkg_deb
-pkg_tar = _pkg_tar
-py_binary = native.py_binary
rbe_platform = native.platform
rbe_toolchain = native.toolchain
-vdso_linker_option = "-fuse-ld=gold "
def short_path(path):
return path
@@ -40,140 +18,6 @@ def proto_library(name, has_services = None, **kwargs):
**kwargs
)
-def cc_grpc_library(name, **kwargs):
- _cc_grpc_library(name = name, grpc_only = True, **kwargs)
-
-def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
- deps = [
- dep.replace("_proto", "_go_proto")
- for dep in (kwargs.pop("deps", []) or [])
- ]
- go_library_func(
- name = name + "_go_proto",
- importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
- proto = ":" + name + "_proto",
- deps = deps,
- **kwargs
- )
-
-def go_proto_library(name, **kwargs):
- _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
-
-def go_grpc_and_proto_libraries(name, **kwargs):
- _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
-
-def cc_binary(name, static = False, **kwargs):
- """Run cc_binary.
-
- Args:
- name: name of the target.
- static: make a static binary if True
- **kwargs: the rest of the args.
- """
- if static:
- # How to statically link a c++ program that uses threads, like for gRPC:
- # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
- if "linkopts" not in kwargs:
- kwargs["linkopts"] = []
- kwargs["linkopts"] += [
- "-static",
- "-lstdc++",
- "-Wl,--whole-archive",
- "-lpthread",
- "-Wl,--no-whole-archive",
- ]
- _cc_binary(
- name = name,
- **kwargs
- )
-
-def go_binary(name, static = False, pure = False, x_defs = None, **kwargs):
- """Build a go binary.
-
- Args:
- name: name of the target.
- static: build a static binary.
- pure: build without cgo.
- x_defs: additional definitions.
- **kwargs: rest of the arguments are passed to _go_binary.
- """
- if static:
- kwargs["static"] = "on"
- if pure:
- kwargs["pure"] = "on"
- _go_binary(
- name = name,
- x_defs = x_defs,
- **kwargs
- )
-
-def go_importpath(target):
- """Returns the importpath for the target."""
- return target[GoLibrary].importpath
-
-def go_library(name, **kwargs):
- _go_library(
- name = name,
- importpath = "gvisor.dev/gvisor/" + native.package_name(),
- **kwargs
- )
-
-def go_test(name, pure = False, library = None, **kwargs):
- """Build a go test.
-
- Args:
- name: name of the output binary.
- pure: should it be built without cgo.
- library: the library to embed.
- **kwargs: rest of the arguments to pass to _go_test.
- """
- if pure:
- kwargs["pure"] = "on"
- if library:
- kwargs["embed"] = [library]
- _go_test(
- name = name,
- **kwargs
- )
-
-def go_rule(rule, implementation, **kwargs):
- """Wraps a rule definition with Go attributes.
-
- Args:
- rule: rule function (typically rule or aspect).
- implementation: implementation function.
- **kwargs: other arguments to pass to rule.
-
- Returns:
- The result of invoking the rule.
- """
- attrs = kwargs.pop("attrs", dict())
- attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data")
- attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib")
- toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"]
- return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs)
-
-def go_test_library(target):
- if hasattr(target.attr, "embed") and len(target.attr.embed) > 0:
- return target.attr.embed[0]
- return None
-
-def go_context(ctx, std = False):
- # We don't change anything for the standard library analysis. All Go files
- # are available in all instances. Note that this includes the standard
- # library sources, which are analyzed by nogo.
- go_ctx = _go_context(ctx)
- return struct(
- go = go_ctx.go,
- env = go_ctx.env,
- nogo_args = [],
- stdlib_srcs = go_ctx.sdk.srcs,
- runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs),
- goos = go_ctx.sdk.goos,
- goarch = go_ctx.sdk.goarch,
- tags = go_ctx.tags,
- )
-
def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs):
values = {
"@bazel_tools//src/conditions:linux_x86_64": amd64,
diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl
new file mode 100644
index 000000000..d388346a5
--- /dev/null
+++ b/tools/bazeldefs/go.bzl
@@ -0,0 +1,142 @@
+"""Go rules."""
+
+load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle")
+load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test")
+load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
+load("//tools/bazeldefs:defs.bzl", "select_arch", "select_system")
+
+gazelle = _gazelle
+go_embed_data = _go_embed_data
+go_path = _go_path
+
+def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
+ deps = [
+ dep.replace("_proto", "_go_proto")
+ for dep in (kwargs.pop("deps", []) or [])
+ ]
+ go_library_func(
+ name = name + "_go_proto",
+ importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
+ proto = ":" + name + "_proto",
+ deps = deps,
+ **kwargs
+ )
+
+def go_proto_library(name, **kwargs):
+ _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
+
+def go_grpc_and_proto_libraries(name, **kwargs):
+ _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
+
+def go_binary(name, static = False, pure = False, x_defs = None, **kwargs):
+ """Build a go binary.
+
+ Args:
+ name: name of the target.
+ static: build a static binary.
+ pure: build without cgo.
+ x_defs: additional definitions.
+ **kwargs: rest of the arguments are passed to _go_binary.
+ """
+ if static:
+ kwargs["static"] = "on"
+ if pure:
+ kwargs["pure"] = "on"
+ _go_binary(
+ name = name,
+ x_defs = x_defs,
+ **kwargs
+ )
+
+def go_importpath(target):
+ """Returns the importpath for the target."""
+ return target[GoLibrary].importpath
+
+def go_library(name, **kwargs):
+ _go_library(
+ name = name,
+ importpath = "gvisor.dev/gvisor/" + native.package_name(),
+ **kwargs
+ )
+
+def go_test(name, pure = False, library = None, **kwargs):
+ """Build a go test.
+
+ Args:
+ name: name of the output binary.
+ pure: should it be built without cgo.
+ library: the library to embed.
+ **kwargs: rest of the arguments to pass to _go_test.
+ """
+ if pure:
+ kwargs["pure"] = "on"
+ if library:
+ kwargs["embed"] = [library]
+ _go_test(
+ name = name,
+ **kwargs
+ )
+
+def go_rule(rule, implementation, **kwargs):
+ """Wraps a rule definition with Go attributes.
+
+ Args:
+ rule: rule function (typically rule or aspect).
+ implementation: implementation function.
+ **kwargs: other arguments to pass to rule.
+
+ Returns:
+ The result of invoking the rule.
+ """
+ attrs = kwargs.pop("attrs", dict())
+ attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data")
+ attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib")
+ toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"]
+ return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs)
+
+def go_test_library(target):
+ if hasattr(target.attr, "embed") and len(target.attr.embed) > 0:
+ return target.attr.embed[0]
+ return None
+
+def go_context(ctx, goos = None, goarch = None, std = False):
+ """Extracts a standard Go context struct.
+
+ Args:
+ ctx: the starlark context (required).
+ goos: the GOOS value.
+ goarch: the GOARCH value.
+ std: ignored.
+
+ Returns:
+ A context Go struct with pointers to Go toolchain components.
+ """
+
+ # We don't change anything for the standard library analysis. All Go files
+ # are available in all instances. Note that this includes the standard
+ # library sources, which are analyzed by nogo.
+ go_ctx = _go_context(ctx)
+ if goos == None:
+ goos = go_ctx.sdk.goos
+ elif goos != go_ctx.sdk.goos:
+ fail("Internal GOOS (%s) doesn't match GoSdk GOOS (%s)." % (goos, go_ctx.sdk.goos))
+ if goarch == None:
+ goarch = go_ctx.sdk.goarch
+ elif goarch != go_ctx.sdk.goarch:
+ fail("Internal GOARCH (%s) doesn't match GoSdk GOARCH (%s)." % (goarch, go_ctx.sdk.goarch))
+ return struct(
+ go = go_ctx.go,
+ env = go_ctx.env,
+ nogo_args = [],
+ stdlib_srcs = go_ctx.sdk.srcs,
+ runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs),
+ goos = go_ctx.sdk.goos,
+ goarch = go_ctx.sdk.goarch,
+ tags = go_ctx.tags,
+ )
+
+def select_goarch():
+ return select_arch(arm64 = "arm64", amd64 = "amd64")
+
+def select_goos():
+ return select_system(linux = "linux")
diff --git a/tools/bazeldefs/pkg.bzl b/tools/bazeldefs/pkg.bzl
new file mode 100644
index 000000000..56317d93f
--- /dev/null
+++ b/tools/bazeldefs/pkg.bzl
@@ -0,0 +1,6 @@
+"""Packaging rules."""
+
+load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
+
+pkg_deb = _pkg_deb
+pkg_tar = _pkg_tar
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
index 2b0062a63..1cea9e1c9 100644
--- a/tools/bigquery/BUILD
+++ b/tools/bigquery/BUILD
@@ -9,5 +9,8 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = ["@com_google_cloud_go_bigquery//:go_default_library"],
+ deps = [
+ "@com_google_cloud_go_bigquery//:go_default_library",
+ "@org_golang_google_api//option:go_default_library",
+ ],
)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
index 5f1a882de..34b270cc0 100644
--- a/tools/bigquery/bigquery.go
+++ b/tools/bigquery/bigquery.go
@@ -25,22 +25,30 @@ import (
"time"
bq "cloud.google.com/go/bigquery"
+ "google.golang.org/api/option"
)
-// Benchmark is the top level structure of recorded benchmark data. BigQuery
+// Suite is the top level structure for a benchmark run. BigQuery
// will infer the schema from this.
+type Suite struct {
+ Name string `bq:"name"`
+ Conditions []*Condition `bq:"conditions"`
+ Benchmarks []*Benchmark `bq:"benchmarks"`
+ Official bool `bq:"official"`
+ Timestamp time.Time `bq:"timestamp"`
+}
+
+// Benchmark represents an individual benchmark in a suite.
type Benchmark struct {
Name string `bq:"name"`
Condition []*Condition `bq:"condition"`
- Timestamp time.Time `bq:"timestamp"`
- Official bool `bq:"official"`
Metric []*Metric `bq:"metric"`
- Metadata *Metadata `bq:"metadata"`
}
-// Condition represents qualifiers for the benchmark. For example:
+// Condition represents qualifiers for the benchmark or suite. For example:
// Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1"
-// and "real_time" parameters as conditions.
+// and "real_time" parameters as conditions. Suite conditions include
+// information such as the CL number and platform name.
type Condition struct {
Name string `bq:"name"`
Value string `bq:"value"`
@@ -53,19 +61,9 @@ type Metric struct {
Sample float64 `bq:"sample"`
}
-// Metadata about this benchmark.
-type Metadata struct {
- CL string `bq:"changelist"`
- IterationID string `bq:"iteration_id"`
- PendingCL string `bq:"pending_cl"`
- Workflow string `bq:"workflow"`
- Platform string `bq:"platform"`
- Gofer string `bq:"gofer"`
-}
-
// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
-func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error {
- client, err := bq.NewClient(ctx, projectID)
+func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opts []option.ClientOption) error {
+ client, err := bq.NewClient(ctx, projectID, opts...)
if err != nil {
return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err)
}
@@ -77,7 +75,7 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) err
}
table := dataset.Table(tableID)
- schema, err := bq.InferSchema(Benchmark{})
+ schema, err := bq.InferSchema(Suite{})
if err != nil {
return fmt.Errorf("failed to infer schema: %v", err)
}
@@ -109,24 +107,32 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
// NewBenchmark initializes a new benchmark.
func NewBenchmark(name string, iters int, official bool) *Benchmark {
return &Benchmark{
- Name: name,
- Timestamp: time.Now().UTC(),
- Official: official,
- Metric: make([]*Metric, 0),
+ Name: name,
+ Metric: make([]*Metric, 0),
+ }
+}
+
+// NewSuite initializes a new Suite.
+func NewSuite(name string) *Suite {
+ return &Suite{
+ Name: name,
+ Timestamp: time.Now().UTC(),
+ Benchmarks: make([]*Benchmark, 0),
+ Conditions: make([]*Condition, 0),
}
}
// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table.
-func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error {
- client, err := bq.NewClient(ctx, projectID)
+func SendBenchmarks(ctx context.Context, suite *Suite, projectID, datasetID, tableID string, opts []option.ClientOption) error {
+ client, err := bq.NewClient(ctx, projectID, opts...)
if err != nil {
return fmt.Errorf("failed to initialize client on project: %s: %v", projectID, err)
}
defer client.Close()
uploader := client.Dataset(datasetID).Table(tableID).Uploader()
- if err = uploader.Put(ctx, benchmarks); err != nil {
- return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err)
+ if err = uploader.Put(ctx, suite); err != nil {
+ return fmt.Errorf("failed to upload benchmarks %s to project %s, table %s.%s: %v", suite.Name, projectID, datasetID, tableID, err)
}
return nil
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
index 523a42692..e5a7e23c7 100644
--- a/tools/checkescape/checkescape.go
+++ b/tools/checkescape/checkescape.go
@@ -67,6 +67,7 @@ import (
"go/token"
"go/types"
"io"
+ "log"
"os"
"os/exec"
"path/filepath"
@@ -619,7 +620,10 @@ func findReasons(pass *analysis.Pass, fdecl *ast.FuncDecl) ([]EscapeReason, bool
func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) {
calls, err := loadObjdump()
if err != nil {
- return nil, err
+ // Note that if this analysis fails, then we don't actually
+ // fail the analyzer itself. We simply report every possible
+ // escape. In most cases this will work just fine.
+ log.Printf("WARNING: unable to load objdump: %v", err)
}
allEscapes := make(map[string][]Escapes)
mergedEscapes := make(map[string]Escapes)
@@ -641,6 +645,11 @@ func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) {
}
hasCall := func(inst poser) (string, bool) {
p := linePosition(inst, nil)
+ if calls == nil {
+ // See above: we don't have access to the binary
+ // itself, so need to include every possible call.
+ return "(possible)", true
+ }
s, ok := calls[p.Simplified()]
if !ok {
return "", false
diff --git a/tools/defs.bzl b/tools/defs.bzl
index e2c72a1f6..d75e40863 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,39 +7,49 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option")
+load("//tools/nogo:defs.bzl", "nogo_test")
+load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _loopback = "loopback", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path")
+load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option")
+load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos")
+load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
load("//tools/bazeldefs:tags.bzl", "go_suffixes")
-load("//tools/nogo:defs.bzl", "nogo_test")
-# Delegate directly.
+# Core rules.
build_test = _build_test
bzl_library = _bzl_library
+default_installer = _default_installer
+default_net_util = _default_net_util
+loopback = _loopback
+select_arch = _select_arch
+select_system = _select_system
+short_path = _short_path
+rbe_platform = _rbe_platform
+rbe_toolchain = _rbe_toolchain
+coreutil = _coreutil
+
+# C++ rules.
cc_binary = _cc_binary
cc_flags_supplier = _cc_flags_supplier
cc_grpc_library = _cc_grpc_library
cc_library = _cc_library
cc_test = _cc_test
cc_toolchain = _cc_toolchain
-default_installer = _default_installer
-default_net_util = _default_net_util
gbenchmark = _gbenchmark
+gtest = _gtest
+grpcpp = _grpcpp
+vdso_linker_option = _vdso_linker_option
+
+# Go rules.
gazelle = _gazelle
go_embed_data = _go_embed_data
go_path = _go_path
-gtest = _gtest
-grpcpp = _grpcpp
-loopback = _loopback
+select_goos = _select_goos
+select_goarch = _select_goarch
+
+# Packaging rules.
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
-py_binary = _py_binary
-select_arch = _select_arch
-select_system = _select_system
-short_path = _short_path
-rbe_platform = _rbe_platform
-rbe_toolchain = _rbe_toolchain
-vdso_linker_option = _vdso_linker_option
-coreutil = _coreutil
# Platform options.
default_platform = _default_platform
@@ -66,14 +76,20 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, **
if nogo:
# Note that the nogo rule applies only for go_library and go_test
# targets, therefore we construct a library from the binary sources.
+ # This is done because the binary may not be in a form that objdump
+ # supports (i.e. a pure Go binary).
_go_library(
name = name + "_nogo_library",
- **kwargs
+ srcs = kwargs.get("srcs", []),
+ deps = kwargs.get("deps", []),
+ testonly = 1,
)
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = kwargs.get("srcs", []),
- library = ":" + name + "_nogo_library",
+ deps = [":" + name + "_nogo_library"],
+ tags = ["nogo"],
)
def calculate_sets(srcs):
@@ -204,8 +220,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
if nogo:
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = all_srcs,
- library = ":" + name,
+ deps = [":" + name],
+ tags = ["nogo"],
)
if marshal:
@@ -241,8 +259,10 @@ def go_test(name, nogo = True, **kwargs):
if nogo:
nogo_test(
name = name + "_nogo",
+ config = "//:nogo_config",
srcs = kwargs.get("srcs", []),
- library = ":" + name,
+ deps = [":" + name],
+ tags = ["nogo"],
)
def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
diff --git a/tools/github/main.go b/tools/github/main.go
index 7a74dc033..681003eef 100644
--- a/tools/github/main.go
+++ b/tools/github/main.go
@@ -20,6 +20,7 @@ import (
"flag"
"fmt"
"io/ioutil"
+ "log"
"os"
"os/exec"
"strings"
@@ -34,21 +35,43 @@ var (
owner string
repo string
tokenFile string
- path string
+ paths stringList
commit string
dryRun bool
)
+type stringList []string
+
+func (s *stringList) String() string {
+ return strings.Join(*s, ",")
+}
+
+func (s *stringList) Set(value string) error {
+ *s = append(*s, value)
+ return nil
+}
+
// Keep the options simple for now. Supports only a single path and repo.
func init() {
flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)")
flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)")
flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)")
- flag.StringVar(&path, "path", ".", "path to scan (required for revive and nogo)")
+ flag.Var(&paths, "path", "path(s) to scan (required for revive and nogo)")
flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)")
flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made")
}
+func filterPaths(paths []string) (existing []string) {
+ for _, path := range paths {
+ if _, err := os.Stat(path); err != nil {
+ log.Printf("WARNING: skipping %v: %v", path, err)
+ continue
+ }
+ existing = append(existing, path)
+ }
+ return
+}
+
func main() {
// Set defaults from the environment.
repository := os.Getenv("GITHUB_REPOSITORY")
@@ -83,8 +106,9 @@ func main() {
flag.Usage()
os.Exit(1)
}
- if len(path) == 0 {
- fmt.Fprintln(flag.CommandLine.Output(), "missing --path option.")
+ filteredPaths := filterPaths(paths)
+ if len(filteredPaths) == 0 {
+ fmt.Fprintln(flag.CommandLine.Output(), "no valid --path options provided.")
flag.Usage()
os.Exit(1)
}
@@ -123,7 +147,7 @@ func main() {
os.Exit(1)
}
// Scan the provided path.
- rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+ rev := reviver.New(filteredPaths, []reviver.Bugger{bugger})
if errs := rev.Run(); len(errs) > 0 {
fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
for _, err := range errs {
@@ -145,7 +169,7 @@ func main() {
}
// Scan all findings.
poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun)
- if err := poster.Walk(path); err != nil {
+ if err := poster.Walk(filteredPaths); err != nil {
fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err)
os.Exit(1)
}
diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD
index 0633eaf19..19b7eec4d 100644
--- a/tools/github/nogo/BUILD
+++ b/tools/github/nogo/BUILD
@@ -10,7 +10,7 @@ go_library(
"//tools/github:__subpackages__",
],
deps = [
- "//tools/nogo/util",
+ "//tools/nogo",
"@com_github_google_go_github_v28//github:go_default_library",
],
)
diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go
index b70dfe63b..27ab1b8eb 100644
--- a/tools/github/nogo/nogo.go
+++ b/tools/github/nogo/nogo.go
@@ -24,7 +24,7 @@ import (
"time"
"github.com/google/go-github/github"
- "gvisor.dev/gvisor/tools/nogo/util"
+ "gvisor.dev/gvisor/tools/nogo"
)
// FindingsPoster is a simple wrapper around the GitHub api.
@@ -35,7 +35,7 @@ type FindingsPoster struct {
dryRun bool
startTime time.Time
- findings map[util.Finding]struct{}
+ findings map[nogo.Finding]struct{}
client *github.Client
}
@@ -47,32 +47,37 @@ func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun
commit: commit,
dryRun: dryRun,
startTime: time.Now(),
- findings: make(map[util.Finding]struct{}),
+ findings: make(map[nogo.Finding]struct{}),
client: client,
}
}
// Walk walks the given path tree for findings files.
-func (p *FindingsPoster) Walk(path string) error {
- return filepath.Walk(path, func(filename string, info os.FileInfo, err error) error {
- if err != nil {
- return err
- }
- // Skip any directories or files not ending in .findings.
- if !strings.HasSuffix(filename, ".findings") || info.IsDir() {
+func (p *FindingsPoster) Walk(paths []string) error {
+ for _, path := range paths {
+ if err := filepath.Walk(path, func(filename string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+ // Skip any directories or files not ending in .findings.
+ if !strings.HasSuffix(filename, ".findings") || info.IsDir() {
+ return nil
+ }
+ findings, err := nogo.ExtractFindingsFromFile(filename)
+ if err != nil {
+ return err
+ }
+ // Add all findings to the list. We use a map to ensure
+ // that each finding is unique.
+ for _, finding := range findings {
+ p.findings[finding] = struct{}{}
+ }
return nil
- }
- findings, err := util.ExtractFindingsFromFile(filename)
- if err != nil {
+ }); err != nil {
return err
}
- // Add all findings to the list. We use a map to ensure
- // that each finding is unique.
- for _, finding := range findings {
- p.findings[finding] = struct{}{}
- }
- return nil
- })
+ }
+ return nil
}
// Post posts all results to the GitHub API as a check run.
@@ -81,7 +86,7 @@ func (p *FindingsPoster) Post() error {
if p.dryRun {
for finding, _ := range p.findings {
// Pretty print, so that this is useful for debugging.
- fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message)
+ fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Position.Filename, finding.Position.Line, finding.Message)
}
return nil
}
@@ -110,12 +115,13 @@ func (p *FindingsPoster) Post() error {
}
annotationLevel := "failure" // Always.
for finding, _ := range p.findings {
+ title := string(finding.Category)
opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{
- Path: &finding.Path,
- StartLine: &finding.Line,
- EndLine: &finding.Line,
+ Path: &finding.Position.Filename,
+ StartLine: &finding.Position.Line,
+ EndLine: &finding.Position.Line,
Message: &finding.Message,
- Title: &finding.Category,
+ Title: &title,
AnnotationLevel: &annotationLevel,
})
}
diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go
index a95df0fb6..c4b624f2a 100644
--- a/tools/github/reviver/github.go
+++ b/tools/github/reviver/github.go
@@ -121,13 +121,24 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) {
return true, nil
}
+var issuePrefixes = []string{
+ "gvisor.dev/issue/",
+ "gvisor.dev/issues/",
+}
+
// parseIssueNo parses the issue number out of the issue url.
+//
+// 0 is returned if url does not correspond to an issue.
func parseIssueNo(url string) (int, error) {
- const prefix = "gvisor.dev/issue/"
-
// First check if I can handle the TODO.
- idStr := strings.TrimPrefix(url, prefix)
- if len(url) == len(idStr) {
+ var idStr string
+ for _, p := range issuePrefixes {
+ if str := strings.TrimPrefix(url, p); len(str) < len(url) {
+ idStr = str
+ break
+ }
+ }
+ if len(idStr) == 0 {
return 0, nil
}
diff --git a/tools/github/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go
index a9fb1f9f1..851306c9d 100644
--- a/tools/github/reviver/reviver_test.go
+++ b/tools/github/reviver/reviver_test.go
@@ -33,6 +33,15 @@ func TestProcessLine(t *testing.T) {
},
},
{
+ line: "// TODO(foobar.com/issues/123): comment, bla. blabla.",
+ want: &Todo{
+ Issue: "foobar.com/issues/123",
+ Locations: []Location{
+ {Comment: "comment, bla. blabla."},
+ },
+ },
+ },
+ {
line: "// FIXME(b/123): internal bug",
want: &Todo{
Issue: "b/123",
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index 33329cf28..ad97208a8 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -1,25 +1,32 @@
-"""Generics support via go_generics."""
+"""Generics support via go_generics.
+
+A Go template is similar to a go library, except that it has certain types that
+can be replaced before usage. For example, one could define a templatized List
+struct, whose elements are of type T, then instantiate that template for
+T=segment, where "segment" is the concrete type.
+"""
TemplateInfo = provider(
+ "Information about a go_generics template.",
fields = {
+ "unsafe": "whether the template requires unsafe code",
"types": "required types",
"opt_types": "optional types",
"consts": "required consts",
"opt_consts": "optional consts",
"deps": "package dependencies",
- "file": "merged template",
+ "template": "merged template source file",
},
)
def _go_template_impl(ctx):
srcs = ctx.files.srcs
- output = ctx.outputs.out
-
- args = ["-o=%s" % output.path] + [f.path for f in srcs]
+ template = ctx.actions.declare_file(ctx.label.name + "_template.go")
+ args = ["-o=%s" % template.path] + [f.path for f in srcs]
ctx.actions.run(
inputs = srcs,
- outputs = [output],
+ outputs = [template],
mnemonic = "GoGenericsTemplate",
progress_message = "Building Go template %s" % ctx.label,
arguments = args,
@@ -32,74 +39,48 @@ def _go_template_impl(ctx):
consts = ctx.attr.consts,
opt_consts = ctx.attr.opt_consts,
deps = ctx.attr.deps,
- file = output,
+ template = template,
)]
-"""
-Generates a Go template from a set of Go files.
-
-A Go template is similar to a go library, except that it has certain types that
-can be replaced before usage. For example, one could define a templatized List
-struct, whose elements are of type T, then instantiate that template for
-T=segment, where "segment" is the concrete type.
-
-Args:
- name: the name of the template.
- srcs: the list of source files that comprise the template.
- types: the list of generic types in the template that are required to be specified.
- opt_types: the list of generic types in the template that can but aren't required to be specified.
- consts: the list of constants in the template that are required to be specified.
- opt_consts: the list of constants in the template that can but aren't required to be specified.
- deps: the list of dependencies.
-"""
go_template = rule(
implementation = _go_template_impl,
attrs = {
- "srcs": attr.label_list(mandatory = True, allow_files = True),
- "deps": attr.label_list(allow_files = True, cfg = "target"),
- "types": attr.string_list(),
- "opt_types": attr.string_list(),
- "consts": attr.string_list(),
- "opt_consts": attr.string_list(),
+ "srcs": attr.label_list(doc = "the list of source files that comprise the template", mandatory = True, allow_files = True),
+ "deps": attr.label_list(doc = "the standard dependency list", allow_files = True, cfg = "target"),
+ "types": attr.string_list(doc = "the list of generic types in the template that are required to be specified"),
+ "opt_types": attr.string_list(doc = "the list of generic types in the template that can but aren't required to be specified"),
+ "consts": attr.string_list(doc = "the list of constants in the template that are required to be specified"),
+ "opt_consts": attr.string_list(doc = "the list of constants in the template that can but aren't required to be specified"),
"_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics/go_merge")),
},
- outputs = {
- "out": "%{name}_template.go",
- },
-)
-
-TemplateInstanceInfo = provider(
- fields = {
- "srcs": "source files",
- },
)
def _go_template_instance_impl(ctx):
- template = ctx.attr.template[TemplateInfo]
+ info = ctx.attr.template[TemplateInfo]
output = ctx.outputs.out
# Check that all required types are defined.
- for t in template.types:
+ for t in info.types:
if t not in ctx.attr.types:
fail("Missing value for type %s in %s" % (t, ctx.attr.template.label))
# Check that all defined types are expected by the template.
for t in ctx.attr.types:
- if (t not in template.types) and (t not in template.opt_types):
+ if (t not in info.types) and (t not in info.opt_types):
fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label))
# Check that all required consts are defined.
- for t in template.consts:
+ for t in info.consts:
if t not in ctx.attr.consts:
fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label))
# Check that all defined consts are expected by the template.
for t in ctx.attr.consts:
- if (t not in template.consts) and (t not in template.opt_consts):
+ if (t not in info.consts) and (t not in info.opt_consts):
fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label))
# Build the argument list.
- args = ["-i=%s" % template.file.path, "-o=%s" % output.path]
+ args = ["-i=%s" % info.template.path, "-o=%s" % output.path]
if ctx.attr.package:
args.append("-p=%s" % ctx.attr.package)
@@ -117,7 +98,7 @@ def _go_template_instance_impl(ctx):
args.append("-anon")
ctx.actions.run(
- inputs = [template.file],
+ inputs = [info.template],
outputs = [output],
mnemonic = "GoGenericsInstance",
progress_message = "Building Go template instance %s" % ctx.label,
@@ -125,35 +106,22 @@ def _go_template_instance_impl(ctx):
executable = ctx.executable._tool,
)
- return [TemplateInstanceInfo(
- srcs = [output],
+ return [DefaultInfo(
+ files = depset([output]),
)]
-"""
-Instantiates a Go template by replacing all generic types with concrete ones.
-
-Args:
- name: the name of the template instance.
- template: the label of the template to be instatiated.
- prefix: a prefix to be added to globals in the template.
- suffix: a suffix to be added to global in the template.
- types: the map from generic type names to concrete ones.
- consts: the map from constant names to their values.
- imports: the map from imports used in types/consts to their import paths.
- package: the name of the package the instantiated template will be compiled into.
-"""
go_template_instance = rule(
implementation = _go_template_instance_impl,
attrs = {
- "template": attr.label(mandatory = True),
- "prefix": attr.string(),
- "suffix": attr.string(),
- "types": attr.string_dict(),
- "consts": attr.string_dict(),
- "imports": attr.string_dict(),
- "anon": attr.bool(mandatory = False, default = False),
- "package": attr.string(mandatory = False),
- "out": attr.output(mandatory = True),
+ "template": attr.label(doc = "the label of the template to be instantiated", mandatory = True),
+ "prefix": attr.string(doc = "a prefix to be added to globals in the template"),
+ "suffix": attr.string(doc = "a suffix to be added to globals in the template"),
+ "types": attr.string_dict(doc = "the map from generic type names to concrete ones"),
+ "consts": attr.string_dict(doc = "the map from constant names to their values"),
+ "imports": attr.string_dict(doc = "the map from imports used in types/consts to their import paths"),
+ "anon": attr.bool(doc = "whether anoymous fields should be processed", mandatory = False, default = False),
+ "package": attr.string(doc = "the package for the generated source file", mandatory = False),
+ "out": attr.output(doc = "output file", mandatory = True),
"_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")),
},
)
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
index dd4b46f58..12b8b597c 100644
--- a/tools/nogo/BUILD
+++ b/tools/nogo/BUILD
@@ -1,8 +1,15 @@
-load("//tools:defs.bzl", "bzl_library", "go_library")
-load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib")
+load("//tools:defs.bzl", "bzl_library", "go_library", "select_goarch", "select_goos")
+load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib", "nogo_target")
package(licenses = ["notice"])
+nogo_target(
+ name = "target",
+ goarch = select_goarch(),
+ goos = select_goos(),
+ visibility = ["//visibility:public"],
+)
+
nogo_objdump_tool(
name = "objdump_tool",
visibility = ["//visibility:public"],
@@ -13,20 +20,14 @@ nogo_stdlib(
visibility = ["//visibility:public"],
)
-sh_binary(
- name = "gentest",
- srcs = ["gentest.sh"],
- visibility = ["//visibility:public"],
-)
-
go_library(
name = "nogo",
srcs = [
+ "analyzers.go",
"build.go",
"config.go",
- "matchers.go",
+ "findings.go",
"nogo.go",
- "register.go",
],
nogo = False,
visibility = ["//:sandbox"],
diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go
new file mode 100644
index 000000000..b919bc2f8
--- /dev/null
+++ b/tools/nogo/analyzers.go
@@ -0,0 +1,131 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/gob"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/asmdecl"
+ "golang.org/x/tools/go/analysis/passes/assign"
+ "golang.org/x/tools/go/analysis/passes/atomic"
+ "golang.org/x/tools/go/analysis/passes/bools"
+ "golang.org/x/tools/go/analysis/passes/buildtag"
+ "golang.org/x/tools/go/analysis/passes/cgocall"
+ "golang.org/x/tools/go/analysis/passes/composite"
+ "golang.org/x/tools/go/analysis/passes/copylock"
+ "golang.org/x/tools/go/analysis/passes/errorsas"
+ "golang.org/x/tools/go/analysis/passes/httpresponse"
+ "golang.org/x/tools/go/analysis/passes/loopclosure"
+ "golang.org/x/tools/go/analysis/passes/lostcancel"
+ "golang.org/x/tools/go/analysis/passes/nilfunc"
+ "golang.org/x/tools/go/analysis/passes/nilness"
+ "golang.org/x/tools/go/analysis/passes/printf"
+ "golang.org/x/tools/go/analysis/passes/shadow"
+ "golang.org/x/tools/go/analysis/passes/shift"
+ "golang.org/x/tools/go/analysis/passes/stdmethods"
+ "golang.org/x/tools/go/analysis/passes/stringintconv"
+ "golang.org/x/tools/go/analysis/passes/structtag"
+ "golang.org/x/tools/go/analysis/passes/tests"
+ "golang.org/x/tools/go/analysis/passes/unmarshal"
+ "golang.org/x/tools/go/analysis/passes/unreachable"
+ "golang.org/x/tools/go/analysis/passes/unsafeptr"
+ "golang.org/x/tools/go/analysis/passes/unusedresult"
+ "honnef.co/go/tools/staticcheck"
+ "honnef.co/go/tools/stylecheck"
+
+ "gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/checkunsafe"
+)
+
+// AllAnalyzers is a list of all available analyzers.
+var AllAnalyzers = []*analysis.Analyzer{
+ asmdecl.Analyzer,
+ assign.Analyzer,
+ atomic.Analyzer,
+ bools.Analyzer,
+ buildtag.Analyzer,
+ cgocall.Analyzer,
+ composite.Analyzer,
+ copylock.Analyzer,
+ errorsas.Analyzer,
+ httpresponse.Analyzer,
+ loopclosure.Analyzer,
+ lostcancel.Analyzer,
+ nilfunc.Analyzer,
+ nilness.Analyzer,
+ printf.Analyzer,
+ shift.Analyzer,
+ stdmethods.Analyzer,
+ stringintconv.Analyzer,
+ shadow.Analyzer,
+ structtag.Analyzer,
+ tests.Analyzer,
+ unmarshal.Analyzer,
+ unreachable.Analyzer,
+ unsafeptr.Analyzer,
+ unusedresult.Analyzer,
+ checkescape.Analyzer,
+ checkunsafe.Analyzer,
+}
+
+// EscapeAnalyzers is a list of escape-related analyzers.
+var EscapeAnalyzers = []*analysis.Analyzer{
+ checkescape.EscapeAnalyzer,
+}
+
+func register(all []*analysis.Analyzer) {
+ // Register all fact types.
+ //
+ // N.B. This needs to be done recursively, because there may be
+ // analyzers in the Requires list that do not appear explicitly above.
+ registered := make(map[*analysis.Analyzer]struct{})
+ var registerOne func(*analysis.Analyzer)
+ registerOne = func(a *analysis.Analyzer) {
+ if _, ok := registered[a]; ok {
+ return
+ }
+
+ // Register dependencies.
+ for _, da := range a.Requires {
+ registerOne(da)
+ }
+
+ // Register local facts.
+ for _, f := range a.FactTypes {
+ gob.Register(f)
+ }
+
+ registered[a] = struct{}{} // Done.
+ }
+ for _, a := range all {
+ registerOne(a)
+ }
+}
+
+func init() {
+ // Add all staticcheck analyzers.
+ for _, a := range staticcheck.Analyzers {
+ AllAnalyzers = append(AllAnalyzers, a)
+ }
+ // Add all stylecheck analyzers.
+ for _, a := range stylecheck.Analyzers {
+ AllAnalyzers = append(AllAnalyzers, a)
+ }
+
+ // Register lists.
+ register(AllAnalyzers)
+ register(EscapeAnalyzers)
+}
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
index 55d34760f..d173cff1f 100644
--- a/tools/nogo/build.go
+++ b/tools/nogo/build.go
@@ -20,22 +20,6 @@ import (
"os"
)
-var (
- // internalPrefix is the internal path prefix. Note that this is not
- // special, as paths should be passed relative to the repository root
- // and should not have any special prefix applied.
- internalPrefix = fmt.Sprintf("^")
-
- // internalDefault is applied when no paths are provided.
- internalDefault = fmt.Sprintf("%s/.*", notPath("external"))
-
- // generatedPrefix is a regex for generated files.
- generatedPrefix = "^(.*/)?(bazel-genfiles|bazel-out|bazel-bin)/"
-
- // externalPrefix is external workspace packages.
- externalPrefix = "^external/"
-)
-
// findStdPkg needs to find the bundled standard library packages.
func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) {
if path == "C" {
diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD
index 21ba2c306..e18483a18 100644
--- a/tools/nogo/check/BUILD
+++ b/tools/nogo/check/BUILD
@@ -2,8 +2,6 @@ load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
-# Note that the check binary must be public, since an aspect may be applied
-# across lots of different rules in different repositories.
go_binary(
name = "check",
srcs = ["main.go"],
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 3828edf3a..69bdfe502 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -16,9 +16,99 @@
package main
import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+
"gvisor.dev/gvisor/tools/nogo"
)
+var (
+ packageFile = flag.String("package", "", "package configuration file (in JSON format)")
+ stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
+ findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
+ factsOutput = flag.String("facts", "", "output file for facts (optional)")
+ escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
+)
+
+func loadConfig(file string, config interface{}) interface{} {
+ // Load the configuration.
+ f, err := os.Open(file)
+ if err != nil {
+ log.Fatalf("unable to open configuration %q: %v", file, err)
+ }
+ defer f.Close()
+ dec := json.NewDecoder(f)
+ dec.DisallowUnknownFields()
+ if err := dec.Decode(config); err != nil {
+ log.Fatalf("unable to decode configuration: %v", err)
+ }
+ return config
+}
+
func main() {
- nogo.Main()
+ // Parse all flags.
+ flag.Parse()
+
+ var (
+ findings []nogo.Finding
+ factData []byte
+ err error
+ )
+
+ // Check & load the configuration.
+ if *packageFile != "" && *stdlibFile != "" {
+ log.Fatalf("unable to perform stdlib and package analysis; provide only one!")
+ }
+
+ // Run the configuration.
+ if *stdlibFile != "" {
+ // Perform basic analysis.
+ c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig)
+ findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers)
+
+ } else if *packageFile != "" {
+ // Perform basic analysis.
+ c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
+ findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
+
+ // Do we need to do escape analysis?
+ if *escapesOutput != "" {
+ escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil)
+ if err != nil {
+ log.Fatalf("error performing escape analysis: %v", err)
+ }
+ if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil {
+ log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err)
+ }
+ }
+ } else {
+ log.Fatalf("please provide at least one of package or stdlib!")
+ }
+
+ // Check that analysis was successful.
+ if err != nil {
+ log.Fatalf("error performing analysis: %v", err)
+ }
+
+ // Save facts.
+ if *factsOutput != "" {
+ if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil {
+ log.Fatalf("error saving findings to %q: %v", *factsOutput, err)
+ }
+ }
+
+ // Write all findings.
+ if *findingsOutput != "" {
+ if err := nogo.WriteFindingsToFile(findings, *findingsOutput); err != nil {
+ log.Fatalf("error writing findings to %q: %v", *findingsOutput, err)
+ }
+ } else {
+ for _, finding := range findings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding.String())
+ }
+ }
}
diff --git a/tools/nogo/config.go b/tools/nogo/config.go
index 8079618ab..2fea5b3e1 100644
--- a/tools/nogo/config.go
+++ b/tools/nogo/config.go
@@ -15,543 +15,247 @@
package nogo
import (
- "golang.org/x/tools/go/analysis"
- "golang.org/x/tools/go/analysis/passes/asmdecl"
- "golang.org/x/tools/go/analysis/passes/assign"
- "golang.org/x/tools/go/analysis/passes/atomic"
- "golang.org/x/tools/go/analysis/passes/bools"
- "golang.org/x/tools/go/analysis/passes/buildtag"
- "golang.org/x/tools/go/analysis/passes/cgocall"
- "golang.org/x/tools/go/analysis/passes/composite"
- "golang.org/x/tools/go/analysis/passes/copylock"
- "golang.org/x/tools/go/analysis/passes/errorsas"
- "golang.org/x/tools/go/analysis/passes/httpresponse"
- "golang.org/x/tools/go/analysis/passes/loopclosure"
- "golang.org/x/tools/go/analysis/passes/lostcancel"
- "golang.org/x/tools/go/analysis/passes/nilfunc"
- "golang.org/x/tools/go/analysis/passes/nilness"
- "golang.org/x/tools/go/analysis/passes/printf"
- "golang.org/x/tools/go/analysis/passes/shadow"
- "golang.org/x/tools/go/analysis/passes/shift"
- "golang.org/x/tools/go/analysis/passes/stdmethods"
- "golang.org/x/tools/go/analysis/passes/stringintconv"
- "golang.org/x/tools/go/analysis/passes/structtag"
- "golang.org/x/tools/go/analysis/passes/tests"
- "golang.org/x/tools/go/analysis/passes/unmarshal"
- "golang.org/x/tools/go/analysis/passes/unreachable"
- "golang.org/x/tools/go/analysis/passes/unsafeptr"
- "golang.org/x/tools/go/analysis/passes/unusedresult"
- "honnef.co/go/tools/staticcheck"
- "honnef.co/go/tools/stylecheck"
-
- "gvisor.dev/gvisor/tools/checkescape"
- "gvisor.dev/gvisor/tools/checkunsafe"
+ "fmt"
+ "regexp"
)
-var analyzerConfig = map[*analysis.Analyzer]matcher{
- // Standard analyzers.
- asmdecl.Analyzer: alwaysMatches(),
- assign.Analyzer: externalExcluded(
- ".*gazelle/walk/walk.go", // False positive.
- ),
- atomic.Analyzer: alwaysMatches(),
- bools.Analyzer: alwaysMatches(),
- buildtag.Analyzer: alwaysMatches(),
- cgocall.Analyzer: alwaysMatches(),
- composite.Analyzer: and(
- disableMatches(), // Disabled for now.
- resultExcluded{
- "Object_",
- "Range{",
- },
- ),
- copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos).
- errorsas.Analyzer: alwaysMatches(),
- httpresponse.Analyzer: alwaysMatches(),
- loopclosure.Analyzer: alwaysMatches(),
- lostcancel.Analyzer: internalMatches(), // Common external issues.
- nilfunc.Analyzer: alwaysMatches(),
- nilness.Analyzer: and(
- internalMatches(), // Common "tautological checks".
- internalExcluded(
- "pkg/sentry/platform/kvm/kvm_test.go", // Intentional.
- "tools/bigquery/bigquery.go", // False positive.
- ),
- ),
- printf.Analyzer: alwaysMatches(),
- shift.Analyzer: alwaysMatches(),
- stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write").
- stringintconv.Analyzer: and(
- internalExcluded(),
- externalExcluded(
- ".*protobuf/.*.go", // Bad conversions.
- ".*flate/huffman_bit_writer.go", // Bad conversion.
+// GroupName is a named group.
+type GroupName string
+
+// AnalyzerName is a named analyzer.
+type AnalyzerName string
+
+// Group represents a named collection of files.
+type Group struct {
+ // Name is the short name for the group.
+ Name GroupName `yaml:"name"`
+
+ // Regex matches all full paths in the group.
+ Regex string `yaml:"regex"`
+ regex *regexp.Regexp `yaml:"-"`
+
+ // Default determines the default group behavior.
+ //
+ // If Default is true, all Analyzers are enabled for this
+ // group. Otherwise, Analyzers must be individually enabled
+ // by specifying a (possible empty) ItemConfig for the group
+ // in the AnalyzerConfig.
+ Default bool `yaml:"default"`
+}
+
+func (g *Group) compile() error {
+ r, err := regexp.Compile(g.Regex)
+ if err != nil {
+ return err
+ }
+ g.regex = r
+ return nil
+}
+
+// ItemConfig is an (Analyzer,Group) configuration.
+type ItemConfig struct {
+ // Exclude are analyzer exclusions.
+ //
+ // Exclude is a list of regular expressions. If the corresponding
+ // Analyzer emits a Finding for which Finding.Position.String()
+ // matches a regular expression in Exclude, the finding will not
+ // be reported.
+ Exclude []string `yaml:"exclude,omitempty"`
+ exclude []*regexp.Regexp `yaml:"-"`
+
+ // Suppress are analyzer suppressions.
+ //
+ // Suppress is a list of regular expressions. If the corresponding
+ // Analyzer emits a Finding for which Finding.Message matches a regular
+ // expression in Suppress, the finding will not be reported.
+ Suppress []string `yaml:"suppress,omitempty"`
+ suppress []*regexp.Regexp `yaml:"-"`
+}
+
+func compileRegexps(ss []string, rs *[]*regexp.Regexp) error {
+ *rs = make([]*regexp.Regexp, 0, len(ss))
+ for _, s := range ss {
+ r, err := regexp.Compile(s)
+ if err != nil {
+ return err
+ }
+ *rs = append(*rs, r)
+ }
+ return nil
+}
+
+func (i *ItemConfig) compile() error {
+ if i == nil {
+ // This may be nil if nothing is included in the
+ // item configuration. That's fine, there's nothing
+ // to compile and nothing to exclude & suppress.
+ return nil
+ }
+ if err := compileRegexps(i.Exclude, &i.exclude); err != nil {
+ return fmt.Errorf("in exclude: %w", err)
+ }
+ if err := compileRegexps(i.Suppress, &i.suppress); err != nil {
+ return fmt.Errorf("in suppress: %w", err)
+ }
+ return nil
+}
+
+func (i *ItemConfig) merge(other *ItemConfig) {
+ i.Exclude = append(i.Exclude, other.Exclude...)
+ i.Suppress = append(i.Suppress, other.Suppress...)
+}
+
+func (i *ItemConfig) shouldReport(fullPos, msg string) bool {
+ if i == nil {
+ // See above.
+ return true
+ }
+ for _, r := range i.exclude {
+ if r.MatchString(fullPos) {
+ return false
+ }
+ }
+ for _, r := range i.suppress {
+ if r.MatchString(msg) {
+ return false
+ }
+ }
+ return true
+}
+
+// AnalyzerConfig is the configuration for a single analyzers.
+//
+// This map is keyed by individual Group names, to allow for different
+// configurations depending on what Group the file belongs to.
+type AnalyzerConfig map[GroupName]*ItemConfig
+
+func (a AnalyzerConfig) compile() error {
+ for name, gc := range a {
+ if err := gc.compile(); err != nil {
+ return fmt.Errorf("invalid group %q: %v", name, err)
+ }
+ }
+ return nil
+}
+
+func (a AnalyzerConfig) merge(other AnalyzerConfig) {
+ // Merge all the groups.
+ for name, gc := range other {
+ old, ok := a[name]
+ if !ok || old == nil {
+ a[name] = gc // Not configured in a.
+ continue
+ }
+ old.merge(gc)
+ }
+}
+
+func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) bool {
+ gc, ok := a[groupConfig.Name]
+ if !ok {
+ return groupConfig.Default
+ }
+
+ // Note that if a section appears for a particular group
+ // for a particular analyzer, then it will now be enabled,
+ // and the group default no longer applies.
+ return gc.shouldReport(fullPos, msg)
+}
+
+// Config is a nogo configuration.
+type Config struct {
+ // Prefixes defines a set of regular expressions that
+ // are standard "prefixes", so that files can be grouped
+ // and specific rules applied to individual groups.
+ Groups []Group `yaml:"groups"`
- // Runtime internal violations.
- ".*reflect/value.go",
- ".*encoding/xml/xml.go",
- ".*runtime/pprof/internal/profile/proto.go",
- ".*fmt/scan.go",
- ".*go/types/conversions.go",
- ".*golang.org/x/net/dns/dnsmessage/message.go",
- ),
- ),
- shadow.Analyzer: disableMatches(), // Disabled for now.
- structtag.Analyzer: internalMatches(), // External not subject to rules.
- tests.Analyzer: alwaysMatches(),
- unmarshal.Analyzer: alwaysMatches(),
- unreachable.Analyzer: internalMatches(),
- unsafeptr.Analyzer: and(
- internalMatches(),
- internalExcluded(
- ".*_test.go", // Exclude tests.
- "pkg/flipcall/.*_unsafe.go", // Special case.
- "pkg/gohacks/gohacks_unsafe.go", // Special case.
- "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case.
- "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case.
- "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case.
- "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case.
- "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case.
- "pkg/sentry/vfs/mount_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case.
- "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case.
- ),
- ),
- unusedresult.Analyzer: alwaysMatches(),
+ // Global is the global analyzer config.
+ Global AnalyzerConfig `yaml:"global"`
- // Internal analyzers: external packages not subject.
- checkescape.Analyzer: internalMatches(),
- checkunsafe.Analyzer: internalMatches(),
+ // Analyzers are individual analyzer configurations. The
+ // key for each analyzer is the name of the analyzer. The
+ // value is either a boolean (enable/disable), or a map to
+ // the groups above.
+ Analyzers map[AnalyzerName]AnalyzerConfig `yaml:"analyzers"`
}
-func init() {
- staticMatcher := and(
- // Only match internal, non-generated files.
- internalMatches(),
- generatedExcluded(),
+// Merge merges two configurations.
+func (c *Config) Merge(other *Config) {
+ // Merge all groups.
+ for _, g := range other.Groups {
+ // Is there a matching group? If yes, we just delete
+ // it. This will preserve the order provided in the
+ // overriding file, even if it differs.
+ for i := 0; i < len(c.Groups); i++ {
+ if g.Name == c.Groups[i].Name {
+ copy(c.Groups[i:], c.Groups[i+1:])
+ c.Groups = c.Groups[:len(c.Groups)-1]
+ break
+ }
+ }
+ c.Groups = append(c.Groups, g)
+ }
- // We use ALL_CAPS for system definitions,
- // which are common enough in the code base
- // that we shouldn't annotate exceptions.
- //
- // Same story for underscores.
- resultExcluded([]string{
- "should not use ALL_CAPS in Go names",
- "should not use underscores in Go names",
- }),
+ // Merge global configurations.
+ c.Global.merge(other.Global)
- // Exclude existing matches.
- internalExcluded(
- "pkg/abi/linux/fuse.go:22",
- "pkg/abi/linux/fuse.go:25",
- "pkg/abi/linux/socket.go:113",
- "pkg/abi/linux/tty.go:73",
- "pkg/bpf/decoder.go:112",
- "pkg/cpuid/cpuid_x86.go:675",
- "pkg/eventchannel/event.go:193",
- "pkg/eventchannel/event.go:27",
- "pkg/eventchannel/event_test.go:22",
- "pkg/eventchannel/rate.go:19",
- "pkg/gohacks/gohacks_unsafe.go:33",
- "pkg/log/json.go:30",
- "pkg/log/log.go:359",
- "pkg/merkletree/merkletree.go:230",
- "pkg/merkletree/merkletree.go:243",
- "pkg/merkletree/merkletree.go:249",
- "pkg/merkletree/merkletree.go:266",
- "pkg/merkletree/merkletree.go:355",
- "pkg/merkletree/merkletree.go:369",
- "pkg/metric/metric_test.go:20",
- "pkg/p9/p9test/client_test.go:687",
- "pkg/p9/transport_test.go:196",
- "pkg/pool/pool.go:15",
- "pkg/refs/refcounter.go:510",
- "pkg/refs/refcounter_test.go:169",
- "pkg/refs_vfs2/refs.go:16",
- "pkg/safemem/block_unsafe.go:89",
- "pkg/seccomp/seccomp.go:82",
- "pkg/segment/test/set_functions.go:15",
- "pkg/sentry/arch/signal.go:166",
- "pkg/sentry/arch/signal.go:171",
- "pkg/sentry/control/pprof.go:196",
- "pkg/sentry/devices/memdev/full.go:58",
- "pkg/sentry/devices/memdev/null.go:59",
- "pkg/sentry/devices/memdev/random.go:68",
- "pkg/sentry/devices/memdev/zero.go:86",
- "pkg/sentry/fdimport/fdimport.go:15",
- "pkg/sentry/fs/attr.go:257",
- "pkg/sentry/fsbridge/fs.go:116",
- "pkg/sentry/fsbridge/vfs.go:124",
- "pkg/sentry/fsbridge/vfs.go:70",
- "pkg/sentry/fs/copy_up.go:365",
- "pkg/sentry/fs/copy_up_test.go:65",
- "pkg/sentry/fs/dev/net_tun.go:161",
- "pkg/sentry/fs/dev/net_tun.go:63",
- "pkg/sentry/fs/dev/null.go:97",
- "pkg/sentry/fs/dirent_cache.go:64",
- "pkg/sentry/fs/file_overlay.go:327",
- "pkg/sentry/fs/file_overlay.go:524",
- "pkg/sentry/fs/filetest/filetest.go:55",
- "pkg/sentry/fs/filetest/filetest.go:60",
- "pkg/sentry/fs/fs.go:77",
- "pkg/sentry/fs/fsutil/file.go:290",
- "pkg/sentry/fs/fsutil/file.go:346",
- "pkg/sentry/fs/fsutil/host_file_mapper.go:105",
- "pkg/sentry/fs/fsutil/inode_cached.go:676",
- "pkg/sentry/fs/fsutil/inode_cached.go:772",
- "pkg/sentry/fs/gofer/attr.go:120",
- "pkg/sentry/fs/gofer/fifo.go:33",
- "pkg/sentry/fs/gofer/inode.go:410",
- "pkg/sentry/fsimpl/devpts/devpts.go:110",
- "pkg/sentry/fsimpl/devpts/devpts.go:246",
- "pkg/sentry/fsimpl/devpts/devpts.go:50",
- "pkg/sentry/fsimpl/devpts/master.go:110",
- "pkg/sentry/fsimpl/devpts/master.go:55",
- "pkg/sentry/fsimpl/devpts/replica.go:113",
- "pkg/sentry/fsimpl/devpts/replica.go:57",
- "pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92",
- "pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44",
- "pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91",
- "pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93",
- "pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66",
- "pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53",
- "pkg/sentry/fsimpl/eventfd/eventfd.go:268",
- "pkg/sentry/fsimpl/ext/directory.go:163",
- "pkg/sentry/fsimpl/ext/directory.go:164",
- "pkg/sentry/fsimpl/ext/extent_file.go:142",
- "pkg/sentry/fsimpl/ext/extent_file.go:143",
- "pkg/sentry/fsimpl/ext/ext.go:105",
- "pkg/sentry/fsimpl/ext/filesystem.go:287",
- "pkg/sentry/fsimpl/ext/regular_file.go:153",
- "pkg/sentry/fsimpl/ext/symlink.go:113",
- "pkg/sentry/fsimpl/fuse/connection_control.go:194",
- "pkg/sentry/fsimpl/fuse/dev.go:387",
- "pkg/sentry/fsimpl/fuse/dev_test.go:318",
- "pkg/sentry/fsimpl/fuse/fusefs.go:102",
- "pkg/sentry/fsimpl/fuse/read_write.go:129",
- "pkg/sentry/fsimpl/fuse/request_response.go:71",
- "pkg/sentry/fsimpl/gofer/directory.go:135",
- "pkg/sentry/fsimpl/gofer/filesystem.go:679",
- "pkg/sentry/fsimpl/gofer/gofer.go:1694",
- "pkg/sentry/fsimpl/gofer/gofer.go:276",
- "pkg/sentry/fsimpl/gofer/regular_file.go:81",
- "pkg/sentry/fsimpl/gofer/special_file.go:141",
- "pkg/sentry/fsimpl/host/host.go:184",
- "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50",
- "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90",
- "pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273",
- "pkg/sentry/fsimpl/kernfs/filesystem.go:247",
- "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320",
- "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497",
- "pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52",
- "pkg/sentry/fsimpl/overlay/directory.go:119",
- "pkg/sentry/fsimpl/overlay/filesystem.go:527",
- "pkg/sentry/fsimpl/overlay/non_directory.go:152",
- "pkg/sentry/fsimpl/overlay/overlay.go:115",
- "pkg/sentry/fsimpl/overlay/overlay.go:719",
- "pkg/sentry/fsimpl/pipefs/pipefs.go:74",
- "pkg/sentry/fsimpl/proc/filesystem.go:52",
- "pkg/sentry/fsimpl/proc/filesystem.go:81",
- "pkg/sentry/fsimpl/proc/subtasks.go:126",
- "pkg/sentry/fsimpl/proc/subtasks.go:189",
- "pkg/sentry/fsimpl/proc/task_fds.go:168",
- "pkg/sentry/fsimpl/proc/task_fds.go:228",
- "pkg/sentry/fsimpl/proc/task_fds.go:301",
- "pkg/sentry/fsimpl/proc/task_fds.go:318",
- "pkg/sentry/fsimpl/proc/task_fds.go:67",
- "pkg/sentry/fsimpl/proc/task_files.go:112",
- "pkg/sentry/fsimpl/proc/task_files.go:158",
- "pkg/sentry/fsimpl/proc/task_files.go:259",
- "pkg/sentry/fsimpl/proc/task_files.go:285",
- "pkg/sentry/fsimpl/proc/task_files.go:305",
- "pkg/sentry/fsimpl/proc/task_files.go:384",
- "pkg/sentry/fsimpl/proc/task_files.go:403",
- "pkg/sentry/fsimpl/proc/task_files.go:428",
- "pkg/sentry/fsimpl/proc/task_files.go:691",
- "pkg/sentry/fsimpl/proc/task_files.go:770",
- "pkg/sentry/fsimpl/proc/task_files.go:797",
- "pkg/sentry/fsimpl/proc/task_files.go:828",
- "pkg/sentry/fsimpl/proc/task_files.go:879",
- "pkg/sentry/fsimpl/proc/task_files.go:910",
- "pkg/sentry/fsimpl/proc/task_files.go:961",
- "pkg/sentry/fsimpl/proc/task.go:127",
- "pkg/sentry/fsimpl/proc/task.go:193",
- "pkg/sentry/fsimpl/proc/task_net.go:134",
- "pkg/sentry/fsimpl/proc/task_net.go:475",
- "pkg/sentry/fsimpl/proc/task_net.go:491",
- "pkg/sentry/fsimpl/proc/task_net.go:508",
- "pkg/sentry/fsimpl/proc/task_net.go:665",
- "pkg/sentry/fsimpl/proc/task_net.go:715",
- "pkg/sentry/fsimpl/proc/task_net.go:779",
- "pkg/sentry/fsimpl/proc/tasks_files.go:113",
- "pkg/sentry/fsimpl/proc/tasks_files.go:388",
- "pkg/sentry/fsimpl/proc/tasks.go:232",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:145",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:181",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:239",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:291",
- "pkg/sentry/fsimpl/proc/tasks_sys.go:375",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:124",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:15",
- "pkg/sentry/fsimpl/signalfd/signalfd.go:126",
- "pkg/sentry/fsimpl/sockfs/sockfs.go:36",
- "pkg/sentry/fsimpl/sockfs/sockfs.go:79",
- "pkg/sentry/fsimpl/sys/kcov.go:49",
- "pkg/sentry/fsimpl/sys/kcov.go:99",
- "pkg/sentry/fsimpl/sys/sys.go:118",
- "pkg/sentry/fsimpl/sys/sys.go:56",
- "pkg/sentry/fsimpl/testutil/testutil.go:257",
- "pkg/sentry/fsimpl/testutil/testutil.go:260",
- "pkg/sentry/fsimpl/timerfd/timerfd.go:87",
- "pkg/sentry/fsimpl/tmpfs/directory.go:112",
- "pkg/sentry/fsimpl/tmpfs/filesystem.go:195",
- "pkg/sentry/fsimpl/tmpfs/regular_file.go:226",
- "pkg/sentry/fsimpl/tmpfs/regular_file.go:346",
- "pkg/sentry/fsimpl/tmpfs/tmpfs.go:103",
- "pkg/sentry/fsimpl/tmpfs/tmpfs.go:733",
- "pkg/sentry/fsimpl/verity/filesystem.go:490",
- "pkg/sentry/fsimpl/verity/verity.go:156",
- "pkg/sentry/fsimpl/verity/verity.go:629",
- "pkg/sentry/fsimpl/verity/verity.go:672",
- "pkg/sentry/fs/mount.go:162",
- "pkg/sentry/fs/mount.go:256",
- "pkg/sentry/fs/mount_overlay.go:144",
- "pkg/sentry/fs/mounts.go:432",
- "pkg/sentry/fs/proc/exec_args.go:104",
- "pkg/sentry/fs/proc/exec_args.go:73",
- "pkg/sentry/fs/proc/fds.go:269",
- "pkg/sentry/fs/proc/loadavg.go:33",
- "pkg/sentry/fs/proc/meminfo.go:39",
- "pkg/sentry/fs/proc/mounts.go:193",
- "pkg/sentry/fs/proc/mounts.go:84",
- "pkg/sentry/fs/proc/net.go:125",
- "pkg/sentry/fs/proc/proc.go:146",
- "pkg/sentry/fs/proc/proc.go:204",
- "pkg/sentry/fs/proc/seqfile/seqfile.go:210",
- "pkg/sentry/fs/proc/sys.go:146",
- "pkg/sentry/fs/proc/sys.go:43",
- "pkg/sentry/fs/proc/sys_net.go:113",
- "pkg/sentry/fs/proc/sys_net.go:205",
- "pkg/sentry/fs/proc/sys_net.go:233",
- "pkg/sentry/fs/proc/sys_net.go:307",
- "pkg/sentry/fs/proc/sys_net.go:335",
- "pkg/sentry/fs/proc/sys_net.go:446",
- "pkg/sentry/fs/proc/sys_net.go:456",
- "pkg/sentry/fs/proc/sys_net.go:89",
- "pkg/sentry/fs/proc/task.go:170",
- "pkg/sentry/fs/proc/task.go:322",
- "pkg/sentry/fs/proc/task.go:427",
- "pkg/sentry/fs/proc/task.go:467",
- "pkg/sentry/fs/proc/task.go:500",
- "pkg/sentry/fs/proc/task.go:784",
- "pkg/sentry/fs/proc/task.go:839",
- "pkg/sentry/fs/proc/task.go:920",
- "pkg/sentry/fs/proc/uid_gid_map.go:108",
- "pkg/sentry/fs/proc/uid_gid_map.go:79",
- "pkg/sentry/fs/proc/uptime.go:75",
- "pkg/sentry/fs/ramfs/dir.go:447",
- "pkg/sentry/fs/tmpfs/inode_file.go:436",
- "pkg/sentry/fs/tmpfs/inode_file.go:537",
- "pkg/sentry/fs/tty/dir.go:313",
- "pkg/sentry/fs/tty/master.go:131",
- "pkg/sentry/fs/tty/master.go:91",
- "pkg/sentry/fs/tty/replica.go:116",
- "pkg/sentry/fs/tty/replica.go:88",
- "pkg/sentry/kernel/auth/id_map.go:269",
- "pkg/sentry/kernel/fasync/fasync.go:67",
- "pkg/sentry/kernel/kcov.go:209",
- "pkg/sentry/kernel/kcov.go:223",
- "pkg/sentry/kernel/kernel.go:343",
- "pkg/sentry/kernel/kernel.go:368",
- "pkg/sentry/kernel/pipe/node_test.go:112",
- "pkg/sentry/kernel/pipe/node_test.go:119",
- "pkg/sentry/kernel/pipe/node_test.go:130",
- "pkg/sentry/kernel/pipe/node_test.go:137",
- "pkg/sentry/kernel/pipe/node_test.go:149",
- "pkg/sentry/kernel/pipe/node_test.go:150",
- "pkg/sentry/kernel/pipe/node_test.go:158",
- "pkg/sentry/kernel/pipe/node_test.go:174",
- "pkg/sentry/kernel/pipe/node_test.go:180",
- "pkg/sentry/kernel/pipe/node_test.go:193",
- "pkg/sentry/kernel/pipe/node_test.go:202",
- "pkg/sentry/kernel/pipe/node_test.go:205",
- "pkg/sentry/kernel/pipe/node_test.go:216",
- "pkg/sentry/kernel/pipe/node_test.go:219",
- "pkg/sentry/kernel/pipe/node_test.go:271",
- "pkg/sentry/kernel/pipe/node_test.go:290",
- "pkg/sentry/kernel/pipe/pipe_test.go:93",
- "pkg/sentry/kernel/pipe/reader_writer.go:65",
- "pkg/sentry/kernel/posixtimer.go:157",
- "pkg/sentry/kernel/ptrace.go:218",
- "pkg/sentry/kernel/semaphore/semaphore.go:323",
- "pkg/sentry/kernel/sessions.go:123",
- "pkg/sentry/kernel/sessions.go:508",
- "pkg/sentry/kernel/signal_handlers.go:57",
- "pkg/sentry/kernel/task_context.go:72",
- "pkg/sentry/kernel/task_exit.go:67",
- "pkg/sentry/kernel/task_sched.go:255",
- "pkg/sentry/kernel/task_sched.go:280",
- "pkg/sentry/kernel/task_sched.go:323",
- "pkg/sentry/kernel/task_stop.go:192",
- "pkg/sentry/kernel/thread_group.go:530",
- "pkg/sentry/kernel/timekeeper.go:316",
- "pkg/sentry/kernel/vdso.go:106",
- "pkg/sentry/kernel/vdso.go:118",
- "pkg/sentry/memmap/memmap.go:103",
- "pkg/sentry/memmap/memmap.go:163",
- "pkg/sentry/mm/address_space.go:42",
- "pkg/sentry/mm/address_space.go:42",
- "pkg/sentry/mm/aio_context.go:208",
- "pkg/sentry/mm/aio_context.go:288",
- "pkg/sentry/mm/pma.go:683",
- "pkg/sentry/mm/special_mappable.go:80",
- "pkg/sentry/platform/systrap/subprocess.go:370",
- "pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go:124",
- "pkg/sentry/socket/control/control.go:260",
- "pkg/sentry/socket/control/control.go:94",
- "pkg/sentry/socket/control/control_vfs2.go:37",
- "pkg/sentry/socket/hostinet/stack.go:433",
- "pkg/sentry/socket/hostinet/stack.go:438",
- "pkg/sentry/socket/hostinet/stack.go:444",
- "pkg/sentry/socket/hostinet/stack.go:460",
- "pkg/sentry/socket/netfilter/tcp_matcher.go:74",
- "pkg/sentry/socket/netfilter/udp_matcher.go:71",
- "pkg/sentry/socket/netlink/route/protocol.go:38",
- "pkg/sentry/socket/socket.go:332",
- "pkg/sentry/socket/unix/transport/connectioned.go:394",
- "pkg/sentry/socket/unix/transport/connectionless.go:152",
- "pkg/sentry/socket/unix/transport/unix.go:436",
- "pkg/sentry/socket/unix/transport/unix.go:490",
- "pkg/sentry/socket/unix/transport/unix.go:685",
- "pkg/sentry/socket/unix/transport/unix.go:795",
- "pkg/sentry/syscalls/linux/sys_sem.go:62",
- "pkg/sentry/syscalls/linux/sys_time.go:189",
- "pkg/sentry/usage/cpu.go:42",
- "pkg/sentry/vfs/anonfs.go:302",
- "pkg/sentry/vfs/anonfs.go:99",
- "pkg/sentry/vfs/dentry.go:214",
- "pkg/sentry/vfs/epoll.go:168",
- "pkg/sentry/vfs/epoll.go:314",
- "pkg/sentry/vfs/file_description.go:549",
- "pkg/sentry/vfs/file_description_impl_util.go:304",
- "pkg/sentry/vfs/file_description_impl_util.go:412",
- "pkg/sentry/vfs/filesystem.go:76",
- "pkg/sentry/vfs/lock.go:15",
- "pkg/sentry/vfs/lock.go:47",
- "pkg/sentry/vfs/memxattr/xattr.go:37",
- "pkg/sentry/vfs/mount.go:510",
- "pkg/sentry/vfs/mount.go:667",
- "pkg/sentry/vfs/mount_test.go:106",
- "pkg/sentry/vfs/mount_test.go:160",
- "pkg/sentry/vfs/mount_test.go:215",
- "pkg/sentry/vfs/mount_unsafe.go:153",
- "pkg/sentry/vfs/resolving_path.go:228",
- "pkg/sentry/vfs/vfs.go:897",
- "pkg/shim/runsc/runsc.go:16",
- "pkg/shim/runsc/utils.go:16",
- "pkg/shim/v1/proc/deleted_state.go:16",
- "pkg/shim/v1/proc/exec.go:16",
- "pkg/shim/v1/proc/exec_state.go:16",
- "pkg/shim/v1/proc/init.go:16",
- "pkg/shim/v1/proc/init_state.go:16",
- "pkg/shim/v1/proc/io.go:16",
- "pkg/shim/v1/proc/process.go:16",
- "pkg/shim/v1/proc/types.go:16",
- "pkg/shim/v1/proc/utils.go:16",
- "pkg/shim/v1/shim/api.go:16",
- "pkg/shim/v1/shim/platform.go:16",
- "pkg/shim/v1/shim/service.go:16",
- "pkg/shim/v1/utils/annotations.go:15",
- "pkg/shim/v1/utils/utils.go:15",
- "pkg/shim/v1/utils/volumes.go:15",
- "pkg/shim/v2/api.go:16",
- "pkg/shim/v2/epoll.go:18",
- "pkg/shim/v2/options/options.go:15",
- "pkg/shim/v2/options/options.go:24",
- "pkg/shim/v2/options/options.go:26",
- "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16",
- "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22",
- "pkg/shim/v2/service.go:15",
- "pkg/shim/v2/service_linux.go:18",
- "pkg/state/tests/integer_test.go:23",
- "pkg/state/tests/integer_test.go:28",
- "pkg/sync/rwmutex_test.go:105",
- "pkg/syserr/host_linux.go:35",
- "pkg/tcpip/adapters/gonet/gonet_test.go:144",
- "pkg/tcpip/adapters/gonet/gonet_test.go:415",
- "pkg/tcpip/adapters/gonet/gonet_test.go:99",
- "pkg/tcpip/buffer/view.go:238",
- "pkg/tcpip/buffer/view.go:238",
- "pkg/tcpip/buffer/view.go:246",
- "pkg/tcpip/header/tcp.go:151",
- "pkg/tcpip/link/sharedmem/pipe/pipe_test.go:493",
- "pkg/tcpip/stack/iptables.go:293",
- "pkg/tcpip/stack/iptables_types.go:277",
- "pkg/tcpip/stack/stack.go:553",
- "pkg/tcpip/stack/transport_test.go:30",
- "pkg/tcpip/transport/packet/endpoint.go:126",
- "pkg/tcpip/transport/raw/endpoint.go:145",
- "pkg/tcpip/transport/tcp/sack_scoreboard.go:167",
- "pkg/unet/unet_test.go:634",
- "pkg/unet/unet_test.go:662",
- "pkg/unet/unet_test.go:703",
- "pkg/unet/unet_test.go:98",
- "pkg/usermem/addr.go:34",
- "pkg/usermem/usermem.go:171",
- "pkg/usermem/usermem.go:170",
- "runsc/boot/compat.go:22",
- "runsc/boot/compat.go:56",
- "runsc/boot/loader.go:1115",
- "runsc/boot/loader.go:1120",
- "runsc/cmd/checkpoint.go:151",
- "runsc/config/flags.go:32",
- "runsc/container/container.go:641",
- "runsc/container/container.go:988",
- "runsc/specutils/specutils.go:172",
- "runsc/specutils/specutils.go:428",
- "runsc/specutils/specutils.go:436",
- "runsc/specutils/specutils.go:442",
- "runsc/specutils/specutils.go:447",
- "runsc/specutils/specutils.go:454",
- "test/cmd/test_app/fds.go:171",
- "test/iptables/filter_output.go:251",
- "test/packetimpact/testbench/connections.go:77",
- "tools/bigquery/bigquery.go:106",
- "tools/checkescape/test1/test1.go:108",
- "tools/checkescape/test1/test1.go:122",
- "tools/checkescape/test1/test1.go:137",
- "tools/checkescape/test1/test1.go:151",
- "tools/checkescape/test1/test1.go:170",
- "tools/checkescape/test1/test1.go:39",
- "tools/checkescape/test1/test1.go:45",
- "tools/checkescape/test1/test1.go:50",
- "tools/checkescape/test1/test1.go:64",
- "tools/checkescape/test1/test1.go:80",
- "tools/checkescape/test1/test1.go:94",
- "tools/go_generics/imports.go:51",
- "tools/go_generics/imports.go:75",
- "tools/go_marshal/gomarshal/generator.go:177",
- "tools/go_marshal/gomarshal/generator.go:81",
- "tools/go_marshal/gomarshal/generator.go:85",
- "tools/go_marshal/test/escape/escape.go:15",
- "tools/go_marshal/test/test.go:164",
- ),
- )
+ // Merge all analyzer configurations.
+ for name, ac := range other.Analyzers {
+ old, ok := c.Analyzers[name]
+ if !ok {
+ c.Analyzers[name] = ac // No analyzer in original config.
+ continue
+ }
+ old.merge(ac)
+ }
+}
- // Add all staticcheck analyzers; internal only.
- for _, a := range staticcheck.Analyzers {
- analyzerConfig[a] = staticMatcher
+// Compile compiles a configuration to make it useable.
+func (c *Config) Compile() error {
+ for i := 0; i < len(c.Groups); i++ {
+ if err := c.Groups[i].compile(); err != nil {
+ return fmt.Errorf("invalid group %q: %w", c.Groups[i].Name, err)
+ }
}
- // Add all stylecheck analyzers; internal only.
- for _, a := range stylecheck.Analyzers {
- analyzerConfig[a] = staticMatcher
+ if err := c.Global.compile(); err != nil {
+ return fmt.Errorf("invalid global: %w", err)
}
+ for name, ac := range c.Analyzers {
+ if err := ac.compile(); err != nil {
+ return fmt.Errorf("invalid analyzer %q: %w", name, err)
+ }
+ }
+ return nil
}
-var escapesConfig = map[*analysis.Analyzer]matcher{
- // Informational only: include all packages.
- checkescape.EscapeAnalyzer: alwaysMatches(),
+// ShouldReport returns true iff the finding should match the Config.
+func (c *Config) ShouldReport(finding Finding) bool {
+ fullPos := finding.Position.String()
+
+ // Find the matching group.
+ var groupConfig *Group
+ for i := 0; i < len(c.Groups); i++ {
+ if c.Groups[i].regex.MatchString(fullPos) {
+ groupConfig = &c.Groups[i]
+ break
+ }
+ }
+
+ // If there is no group matching this path, then
+ // we default to accept the finding.
+ if groupConfig == nil {
+ return true
+ }
+
+ // Suppress via global rule?
+ if !c.Global.shouldReport(groupConfig, fullPos, finding.Message) {
+ return false
+ }
+
+ // Try the analyzer config.
+ ac, ok := c.Analyzers[finding.Category]
+ if !ok {
+ return groupConfig.Default
+ }
+ return ac.shouldReport(groupConfig, fullPos, finding.Message)
}
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index c6fcfd402..b3d297308 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -1,10 +1,59 @@
"""Nogo rules."""
-load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule", "go_test_library")
+load("//tools/bazeldefs:go.bzl", "go_context", "go_importpath", "go_rule", "go_test_library")
-def _nogo_objdump_tool_impl(ctx):
- go_ctx = go_context(ctx)
+NogoConfigInfo = provider(
+ "information about a nogo configuration",
+ fields = {
+ "srcs": "the collection of configuration files",
+ },
+)
+
+def _nogo_config_impl(ctx):
+ return [NogoConfigInfo(
+ srcs = ctx.files.srcs,
+ )]
+
+nogo_config = rule(
+ implementation = _nogo_config_impl,
+ attrs = {
+ "srcs": attr.label_list(
+ doc = "a list of yaml files (schema defined by tool/nogo/config.go).",
+ allow_files = True,
+ ),
+ },
+)
+
+NogoTargetInfo = provider(
+ "information about the Go target",
+ fields = {
+ "goarch": "the build architecture (GOARCH)",
+ "goos": "the build OS target (GOOS)",
+ },
+)
+
+def _nogo_target_impl(ctx):
+ return [NogoTargetInfo(
+ goarch = ctx.attr.goarch,
+ goos = ctx.attr.goos,
+ )]
+nogo_target = go_rule(
+ rule,
+ implementation = _nogo_target_impl,
+ attrs = {
+ "goarch": attr.string(
+ doc = "the Go build architecture (propagated to other rules).",
+ mandatory = True,
+ ),
+ "goos": attr.string(
+ doc = "the Go OS target (propagated to other rules).",
+ mandatory = True,
+ ),
+ },
+)
+
+def _nogo_objdump_tool_impl(ctx):
# Construct the magic dump command.
#
# Note that in some cases, the input is being fed into the tool via stdin.
@@ -12,6 +61,8 @@ def _nogo_objdump_tool_impl(ctx):
# we need the tool to handle this case by creating a temporary file.
#
# [1] https://github.com/golang/go/issues/41051
+ nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo]
+ go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()])
dumper = ctx.actions.declare_file(ctx.label.name)
ctx.actions.write(dumper, "\n".join([
@@ -42,6 +93,12 @@ def _nogo_objdump_tool_impl(ctx):
nogo_objdump_tool = go_rule(
rule,
implementation = _nogo_objdump_tool_impl,
+ attrs = {
+ "_nogo_target": attr.label(
+ default = "//tools/nogo:target",
+ cfg = "target",
+ ),
+ },
)
# NogoStdlibInfo is the set of standard library facts.
@@ -49,16 +106,16 @@ NogoStdlibInfo = provider(
"information for nogo analysis (standard library facts)",
fields = {
"facts": "serialized standard library facts",
- "findings": "package findings (if relevant)",
+ "raw_findings": "raw package findings (if relevant)",
},
)
def _nogo_stdlib_impl(ctx):
- go_ctx = go_context(ctx)
-
# Build the standard library facts.
+ nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo]
+ go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(ctx.label.name + ".facts")
- findings = ctx.actions.declare_file(ctx.label.name + ".findings")
+ raw_findings = ctx.actions.declare_file(ctx.label.name + ".raw_findings")
config = struct(
Srcs = [f.path for f in go_ctx.stdlib_srcs],
GOOS = go_ctx.goos,
@@ -69,15 +126,15 @@ def _nogo_stdlib_impl(ctx):
ctx.actions.write(config_file, config.to_json())
ctx.actions.run(
inputs = [config_file] + go_ctx.stdlib_srcs,
- outputs = [facts, findings],
- tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool),
- executable = ctx.files._nogo[0],
- mnemonic = "GoStandardLibraryAnalysis",
+ outputs = [facts, raw_findings],
+ tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
+ executable = ctx.files._nogo_check[0],
+ mnemonic = "NogoStandardLibraryAnalysis",
progress_message = "Analyzing Go Standard Library",
arguments = go_ctx.nogo_args + [
- "-objdump_tool=%s" % ctx.files._objdump_tool[0].path,
+ "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path,
"-stdlib=%s" % config_file.path,
- "-findings=%s" % findings.path,
+ "-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
],
)
@@ -85,18 +142,24 @@ def _nogo_stdlib_impl(ctx):
# Return the stdlib facts as output.
return [NogoStdlibInfo(
facts = facts,
- findings = findings,
+ raw_findings = raw_findings,
)]
nogo_stdlib = go_rule(
rule,
implementation = _nogo_stdlib_impl,
attrs = {
- "_nogo": attr.label(
+ "_nogo_check": attr.label(
default = "//tools/nogo/check:check",
+ cfg = "host",
),
- "_objdump_tool": attr.label(
+ "_nogo_objdump_tool": attr.label(
default = "//tools/nogo:objdump_tool",
+ cfg = "host",
+ ),
+ "_nogo_target": attr.label(
+ default = "//tools/nogo:target",
+ cfg = "target",
),
},
)
@@ -110,23 +173,22 @@ NogoInfo = provider(
"information for nogo analysis",
fields = {
"facts": "serialized package facts",
- "findings": "package findings (if relevant)",
+ "raw_findings": "raw package findings (if relevant)",
+ "escapes": "escape-only findings (if relevant)",
"importpath": "package import path",
"binaries": "package binary files",
- "srcs": "original source files (for go_test support)",
- "deps": "original deps (for go_test support)",
+ "srcs": "srcs (for go_test support)",
+ "deps": "deps (for go_test support)",
},
)
def _nogo_aspect_impl(target, ctx):
- go_ctx = go_context(ctx)
-
# If this is a nogo rule itself (and not the shadow of a go_library or
# go_binary rule created by such a rule), then we simply return nothing.
# All work is done in the shadow properties for go rules. For a proto
# library, we simply skip the analysis portion but still need to return a
# valid NogoInfo to reference the generated binary.
- if ctx.rule.kind in ("go_library", "go_binary", "go_test", "go_tool_library"):
+ if ctx.rule.kind in ("go_library", "go_tool_library", "go_binary", "go_test"):
srcs = ctx.rule.files.srcs
deps = ctx.rule.attr.deps
elif ctx.rule.kind in ("go_proto_library", "go_wrap_cc"):
@@ -178,6 +240,7 @@ def _nogo_aspect_impl(target, ctx):
# Collect all info from shadow dependencies.
fact_map = dict()
import_map = dict()
+ all_raw_findings = []
for dep in deps:
# There will be no file attribute set for all transitive dependencies
# that are not go_library or go_binary rules, such as a proto rules.
@@ -195,17 +258,23 @@ def _nogo_aspect_impl(target, ctx):
import_map[info.importpath] = x_files[0]
fact_map[info.importpath] = info.facts.path
+ # Collect all findings; duplicates are resolved at the end.
+ all_raw_findings.extend(info.raw_findings)
+
# Ensure the above are available as inputs.
inputs.append(info.facts)
inputs += info.binaries
# Add the standard library facts.
- stdlib_facts = ctx.attr._nogo_stdlib[NogoStdlibInfo].facts
+ stdlib_info = ctx.attr._nogo_stdlib[NogoStdlibInfo]
+ stdlib_facts = stdlib_info.facts
inputs.append(stdlib_facts)
# The nogo tool operates on a configuration serialized in JSON format.
+ nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo]
+ go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(target.label.name + ".facts")
- findings = ctx.actions.declare_file(target.label.name + ".findings")
+ raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings")
escapes = ctx.actions.declare_file(target.label.name + ".escapes")
config = struct(
ImportPath = importpath,
@@ -223,39 +292,39 @@ def _nogo_aspect_impl(target, ctx):
inputs.append(config_file)
ctx.actions.run(
inputs = inputs,
- outputs = [facts, findings, escapes],
- tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool),
- executable = ctx.files._nogo[0],
- mnemonic = "GoStaticAnalysis",
+ outputs = [facts, raw_findings, escapes],
+ tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
+ executable = ctx.files._nogo_check[0],
+ mnemonic = "NogoAnalysis",
progress_message = "Analyzing %s" % target.label,
arguments = go_ctx.nogo_args + [
"-binary=%s" % target_objfile.path,
- "-objdump_tool=%s" % ctx.files._objdump_tool[0].path,
+ "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path,
"-package=%s" % config_file.path,
- "-findings=%s" % findings.path,
+ "-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
"-escapes=%s" % escapes.path,
],
)
+ # Flatten all findings from all dependencies.
+ #
+ # This is done because all the filtering must be done at the
+ # top-level nogo_test to dynamically apply a configuration.
+ # This does not actually add any additional work here, but
+ # will simply propagate the full list of files.
+ all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings]
+
# Return the package facts as output.
- return [
- NogoInfo(
- facts = facts,
- findings = findings,
- importpath = importpath,
- binaries = binaries,
- srcs = srcs,
- deps = deps,
- ),
- OutputGroupInfo(
- # Expose all findings (should just be a single file). This can be
- # used for build analysis of the nogo findings.
- nogo_findings = depset([findings]),
- # Expose all escape analysis findings (see above).
- nogo_escapes = depset([escapes]),
- ),
- ]
+ return [NogoInfo(
+ facts = facts,
+ raw_findings = all_raw_findings,
+ escapes = escapes,
+ importpath = importpath,
+ binaries = binaries,
+ srcs = srcs,
+ deps = deps,
+ )]
nogo_aspect = go_rule(
aspect,
@@ -266,50 +335,94 @@ nogo_aspect = go_rule(
"embed",
],
attrs = {
- "_nogo": attr.label(default = "//tools/nogo/check:check"),
- "_nogo_stdlib": attr.label(default = "//tools/nogo:stdlib"),
- "_objdump_tool": attr.label(default = "//tools/nogo:objdump_tool"),
+ "_nogo_check": attr.label(
+ default = "//tools/nogo/check:check",
+ cfg = "host",
+ ),
+ "_nogo_stdlib": attr.label(
+ default = "//tools/nogo:stdlib",
+ cfg = "host",
+ ),
+ "_nogo_objdump_tool": attr.label(
+ default = "//tools/nogo:objdump_tool",
+ cfg = "host",
+ ),
+ "_nogo_target": attr.label(
+ default = "//tools/nogo:target",
+ cfg = "target",
+ ),
},
)
def _nogo_test_impl(ctx):
"""Check nogo findings."""
- # Build a runner that checks the facts files.
- findings = [dep[NogoInfo].findings for dep in ctx.attr.deps]
- runner = ctx.actions.declare_file(ctx.label.name)
+ # Ensure there's a single dependency.
+ if len(ctx.attr.deps) != 1:
+ fail("nogo_test requires exactly one dep.")
+ raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings
+ escapes = ctx.attr.deps[0][NogoInfo].escapes
+
+ # Build a step that applies the configuration.
+ config_srcs = ctx.attr.config[NogoConfigInfo].srcs
+ findings = ctx.actions.declare_file(ctx.label.name + ".findings")
ctx.actions.run(
- inputs = findings + ctx.files.srcs,
- outputs = [runner],
- tools = depset(ctx.files._gentest),
- executable = ctx.files._gentest[0],
- mnemonic = "Gentest",
+ inputs = raw_findings + ctx.files.srcs + config_srcs,
+ outputs = [findings],
+ tools = depset(ctx.files._filter),
+ executable = ctx.files._filter[0],
+ mnemonic = "GoStaticAnalysis",
progress_message = "Generating %s" % ctx.label,
- arguments = [runner.path] + [f.path for f in findings],
+ arguments = ["-input=%s" % f.path for f in raw_findings] +
+ ["-config=%s" % f.path for f in config_srcs] +
+ ["-output=%s" % findings.path],
)
+
+ # Build a runner that checks the filtered facts.
+ #
+ # Note that this calls the filter binary without any configuration, so all
+ # findings will be included. But this is expected, since we've already
+ # filtered out everything that should not be included.
+ runner = ctx.actions.declare_file(ctx.label.name)
+ runner_content = [
+ "#!/bin/bash",
+ "exec %s -input=%s" % (ctx.files._filter[0].short_path, findings.short_path),
+ "",
+ ]
+ ctx.actions.write(runner, "\n".join(runner_content), is_executable = True)
+
return [DefaultInfo(
+ # The runner just executes the filter again, on the
+ # newly generated filtered findings. We still need
+ # the filter tool as part of our runfiles, however.
+ runfiles = ctx.runfiles(files = ctx.files._filter + [findings]),
executable = runner,
+ ), OutputGroupInfo(
+ # Propagate the filtered filters, for consumption by
+ # build tooling. Note that the build tooling typically
+ # pays attention to the mnemoic above, so this must be
+ # what is expected by the tooling.
+ nogo_findings = depset([findings]),
+ # Expose all escape analysis findings (see above).
+ nogo_escapes = depset([escapes]),
)]
-_nogo_test = rule(
+nogo_test = rule(
implementation = _nogo_test_impl,
attrs = {
- # deps should have only a single element.
- "deps": attr.label_list(aspects = [nogo_aspect]),
- # srcs exist here only to ensure that this target is
- # directly affected by changes to the source files.
- "srcs": attr.label_list(allow_files = True),
- "_gentest": attr.label(default = "//tools/nogo:gentest"),
+ "config": attr.label(
+ mandatory = True,
+ doc = "A rule of kind nogo_config.",
+ ),
+ "deps": attr.label_list(
+ aspects = [nogo_aspect],
+ doc = "Exactly one Go dependency to be analyzed.",
+ ),
+ "srcs": attr.label_list(
+ allow_files = True,
+ doc = "Relevant src files. This is ignored except to make the nogo_test directly affected by the files.",
+ ),
+ "_filter": attr.label(default = "//tools/nogo/filter:filter"),
},
test = True,
)
-
-def nogo_test(name, srcs, library, **kwargs):
- tags = kwargs.pop("tags", []) + ["nogo"]
- _nogo_test(
- name = name,
- srcs = srcs,
- deps = [library],
- tags = tags,
- **kwargs
- )
diff --git a/tools/nogo/filter/BUILD b/tools/nogo/filter/BUILD
new file mode 100644
index 000000000..e56a783e2
--- /dev/null
+++ b/tools/nogo/filter/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "filter",
+ srcs = ["main.go"],
+ nogo = False,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tools/nogo",
+ "@in_gopkg_yaml_v2//:go_default_library",
+ ],
+)
diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go
new file mode 100644
index 000000000..9cf41b3b0
--- /dev/null
+++ b/tools/nogo/filter/main.go
@@ -0,0 +1,131 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary check is the nogo entrypoint.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "strings"
+
+ yaml "gopkg.in/yaml.v2"
+ "gvisor.dev/gvisor/tools/nogo"
+)
+
+type stringList []string
+
+func (s *stringList) String() string {
+ return strings.Join(*s, ",")
+}
+
+func (s *stringList) Set(value string) error {
+ *s = append(*s, value)
+ return nil
+}
+
+var (
+ inputFiles stringList
+ configFiles stringList
+ outputFile string
+ showConfig bool
+)
+
+func init() {
+ flag.Var(&inputFiles, "input", "findings input files")
+ flag.StringVar(&outputFile, "output", "", "findings output file")
+ flag.Var(&configFiles, "config", "findings configuration files")
+ flag.BoolVar(&showConfig, "show-config", false, "dump configuration only")
+}
+
+func main() {
+ flag.Parse()
+
+ // Load all available findings.
+ var findings []nogo.Finding
+ for _, filename := range inputFiles {
+ inputFindings, err := nogo.ExtractFindingsFromFile(filename)
+ if err != nil {
+ log.Fatalf("unable to extract findings from %s: %v", filename, err)
+ }
+ findings = append(findings, inputFindings...)
+ }
+
+ // Open and merge all configuations.
+ config := &nogo.Config{
+ Global: make(nogo.AnalyzerConfig),
+ Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig),
+ }
+ for _, filename := range configFiles {
+ content, err := ioutil.ReadFile(filename)
+ if err != nil {
+ log.Fatalf("unable to read %s: %v", filename, err)
+ }
+ var newConfig nogo.Config // For current file.
+ if err := yaml.Unmarshal(content, &newConfig); err != nil {
+ log.Fatalf("unable to decode %s: %v", filename, err)
+ }
+ config.Merge(&newConfig)
+ if showConfig {
+ bytes, err := yaml.Marshal(&newConfig)
+ if err != nil {
+ log.Fatalf("error marshalling config: %v", err)
+ }
+ mergedBytes, err := yaml.Marshal(config)
+ if err != nil {
+ log.Fatalf("error marshalling config: %v", err)
+ }
+ fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes))
+ fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes))
+ }
+ }
+ if err := config.Compile(); err != nil {
+ log.Fatalf("error compiling config: %v", err)
+ }
+ if showConfig {
+ os.Exit(0)
+ }
+
+ // Filter the findings (and aggregate by group).
+ filteredFindings := make([]nogo.Finding, 0, len(findings))
+ for _, finding := range findings {
+ if ok := config.ShouldReport(finding); ok {
+ filteredFindings = append(filteredFindings, finding)
+ }
+ }
+
+ // Write the output (if required).
+ //
+ // If the outputFile is specified, then we exit here. Otherwise,
+ // we continue to write to stdout and treat like a test.
+ if outputFile != "" {
+ if err := nogo.WriteFindingsToFile(filteredFindings, outputFile); err != nil {
+ log.Fatalf("unable to write findings: %v", err)
+ }
+ return
+ }
+
+ // Treat the run as a test.
+ if len(filteredFindings) == 0 {
+ fmt.Fprintf(os.Stdout, "PASS\n")
+ os.Exit(0)
+ }
+ for _, finding := range filteredFindings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding.String())
+ }
+ os.Exit(1)
+}
diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go
new file mode 100644
index 000000000..5bd850269
--- /dev/null
+++ b/tools/nogo/findings.go
@@ -0,0 +1,63 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/json"
+ "fmt"
+ "go/token"
+ "io/ioutil"
+)
+
+// Finding is a single finding.
+type Finding struct {
+ Category AnalyzerName
+ Position token.Position
+ Message string
+}
+
+// String implements fmt.Stringer.String.
+func (f *Finding) String() string {
+ return fmt.Sprintf("%s: %s: %s", f.Category, f.Position.String(), f.Message)
+}
+
+// WriteFindingsToFile writes findings to a file.
+func WriteFindingsToFile(findings []Finding, filename string) error {
+ content, err := WriteFindingsToBytes(findings)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filename, content, 0644)
+}
+
+// WriteFindingsToBytes serializes findings as bytes.
+func WriteFindingsToBytes(findings []Finding) ([]byte, error) {
+ return json.Marshal(findings)
+}
+
+// ExtractFindingsFromFile loads findings from a file.
+func ExtractFindingsFromFile(filename string) ([]Finding, error) {
+ content, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return nil, err
+ }
+ return ExtractFindingsFromBytes(content)
+}
+
+// ExtractFindingsFromBytes loads findings from bytes.
+func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) {
+ err = json.Unmarshal(content, &findings)
+ return findings, err
+}
diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh
deleted file mode 100755
index 033da11ad..000000000
--- a/tools/nogo/gentest.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/bin/bash
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -euo pipefail
-
-if [[ "$#" -lt 2 ]]; then
- echo "usage: $0 <output> <findings...>"
- exit 2
-fi
-declare violations=0
-declare output=$1
-shift
-
-# Start the script.
-echo "#!/bin/sh" > "${output}"
-
-# Read a list of findings files.
-declare filename
-declare line
-for filename in "$@"; do
- if [[ -z "${filename}" ]]; then
- continue
- fi
- while read -r line; do
- violations=$((${violations}+1));
- echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}"
- done < "${filename}"
-done
-
-# Show violations.
-if [[ "${violations}" -eq 0 ]]; then
- echo "echo -e '\\033[0;32mPASS\\033[0;31m\\033[0m'" >> "${output}"
-else
- echo "exit 1" >> "${output}"
-fi
diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch
deleted file mode 100644
index 6b64b2e85..000000000
--- a/tools/nogo/io_bazel_rules_go-visibility.patch
+++ /dev/null
@@ -1,25 +0,0 @@
-diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch
-index 133fbccc..5f0d9a47 100644
---- a/third_party/org_golang_x_tools-extras.patch
-+++ b/third_party/org_golang_x_tools-extras.patch
-@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
-
- go_library(
- name = "go_default_library",
--@@ -14,6 +14,23 @@
-+@@ -14,6 +14,20 @@
- ],
- )
-
-@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
- + "imports.go",
- + ],
- + importpath = "golang.org/x/tools/go/analysis/internal/facts",
--+ visibility = [
--+ "//go/analysis:__subpackages__",
--+ "@io_bazel_rules_go//go/tools/builders:__pkg__",
--+ ],
-++ visibility = ["//visibility:public"],
- + deps = [
- + "//go/analysis:go_tool_library",
- + "//go/types/objectpath:go_tool_library",
diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go
deleted file mode 100644
index b7b73fa27..000000000
--- a/tools/nogo/matchers.go
+++ /dev/null
@@ -1,172 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package nogo
-
-import (
- "go/token"
- "regexp"
- "strings"
-
- "golang.org/x/tools/go/analysis"
-)
-
-type matcher interface {
- ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool
-}
-
-// pathRegexps filters explicit paths.
-type pathRegexps struct {
- expr []*regexp.Regexp
-
- // include, if true, indicates that paths matching any regexp in expr
- // match.
- //
- // If false, paths matching no regexps in expr match.
- include bool
-}
-
-// buildRegexps builds a list of regular expressions.
-//
-// This will panic on error.
-func buildRegexps(prefix string, args ...string) []*regexp.Regexp {
- result := make([]*regexp.Regexp, 0, len(args))
- for _, arg := range args {
- result = append(result, regexp.MustCompile(prefix+arg))
- }
- return result
-}
-
-// notPath works around the lack of backtracking.
-//
-// It is used to construct a regular expression for non-matching components.
-func notPath(name string) string {
- sb := strings.Builder{}
- sb.WriteString("(")
- for i := range name {
- if i > 0 {
- sb.WriteString("|")
- }
- sb.WriteString(name[:i])
- sb.WriteString("[^")
- sb.WriteByte(name[i])
- sb.WriteString("/][^/]*")
- }
- sb.WriteString(")")
- return sb.String()
-}
-
-// ShouldReport implements matcher.ShouldReport.
-func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
- fullPos := fs.Position(d.Pos).String()
- for _, path := range p.expr {
- if path.MatchString(fullPos) {
- return p.include
- }
- }
- return !p.include
-}
-
-// internalExcluded excludes specific internal paths.
-func internalExcluded(paths ...string) *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(internalPrefix, paths...),
- include: false,
- }
-}
-
-// excludedExcluded excludes specific external paths.
-func externalExcluded(paths ...string) *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(externalPrefix, paths...),
- include: false,
- }
-}
-
-// internalMatches returns a path matcher for internal packages.
-func internalMatches() *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(internalPrefix, internalDefault),
- include: true,
- }
-}
-
-// generatedExcluded excludes all generated code.
-func generatedExcluded() *pathRegexps {
- return &pathRegexps{
- expr: buildRegexps(generatedPrefix, ".*"),
- include: false,
- }
-}
-
-// resultExcluded excludes explicit message contents.
-type resultExcluded []string
-
-// ShouldReport implements matcher.ShouldReport.
-func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool {
- for _, str := range r {
- if strings.Contains(d.Message, str) {
- return false
- }
- }
- return true // Not excluded.
-}
-
-// andMatcher is a composite matcher.
-type andMatcher struct {
- all []matcher
-}
-
-// ShouldReport implements matcher.ShouldReport.
-func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
- for _, m := range a.all {
- if !m.ShouldReport(d, fs) {
- return false
- }
- }
- return true
-}
-
-// and is a syntactic convension for andMatcher.
-func and(ms ...matcher) *andMatcher {
- return &andMatcher{
- all: ms,
- }
-}
-
-// anyMatcher matches everything.
-type anyMatcher struct{}
-
-// ShouldReport implements matcher.ShouldReport.
-func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
- return true
-}
-
-// alwaysMatches returns an anyMatcher instance.
-func alwaysMatches() anyMatcher {
- return anyMatcher{}
-}
-
-// neverMatcher will never match.
-type neverMatcher struct{}
-
-// ShouldReport implements matcher.ShouldReport.
-func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
- return false
-}
-
-// disableMatches returns a neverMatcher instance.
-func disableMatches() neverMatcher {
- return neverMatcher{}
-}
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
index 120fdcff5..779d4d6d8 100644
--- a/tools/nogo/nogo.go
+++ b/tools/nogo/nogo.go
@@ -21,7 +21,6 @@ package nogo
import (
"encoding/json"
"errors"
- "flag"
"fmt"
"go/ast"
"go/build"
@@ -45,20 +44,20 @@ import (
"gvisor.dev/gvisor/tools/checkescape"
)
-// stdlibConfig is serialized as the configuration.
+// StdlibConfig is serialized as the configuration.
//
// This contains everything required for stdlib analysis.
-type stdlibConfig struct {
+type StdlibConfig struct {
Srcs []string
GOOS string
GOARCH string
Tags []string
}
-// packageConfig is serialized as the configuration.
+// PackageConfig is serialized as the configuration.
//
// This contains everything required for single package analysis.
-type packageConfig struct {
+type PackageConfig struct {
ImportPath string
GoFiles []string
NonGoFiles []string
@@ -84,7 +83,7 @@ type saver func([]byte) error
//
// This is done because all stdlib data is stored together, and we don't want
// to load this data many times over.
-func (c *packageConfig) factLoader() (loader, error) {
+func (c *PackageConfig) factLoader() (loader, error) {
allFacts := make(map[string][]byte)
if c.StdlibFacts != "" {
data, err := ioutil.ReadFile(c.StdlibFacts)
@@ -114,7 +113,7 @@ func (c *packageConfig) factLoader() (loader, error) {
// shouldInclude indicates whether the file should be included.
//
// NOTE: This does only basic parsing of tags.
-func (c *packageConfig) shouldInclude(path string) (bool, error) {
+func (c *PackageConfig) shouldInclude(path string) (bool, error) {
ctx := build.Default
ctx.GOOS = c.GOOS
ctx.GOARCH = c.GOARCH
@@ -128,7 +127,7 @@ func (c *packageConfig) shouldInclude(path string) (bool, error) {
// files, and the facts. Note that this importer implementation will always
// pass when a given package is not available.
type importer struct {
- *packageConfig
+ *PackageConfig
fset *token.FileSet
cache map[string]*types.Package
lastErr error
@@ -185,14 +184,14 @@ func (i *importer) Import(path string) (*types.Package, error) {
// ErrSkip indicates the package should be skipped.
var ErrSkip = errors.New("skipped")
-// checkStdlib checks the standard library.
+// CheckStdlib checks the standard library.
//
// This constructs a synthetic package configuration for each library in the
-// standard library sources, and call checkPackage repeatedly.
+// standard library sources, and call CheckPackage repeatedly.
//
// Note that not all parts of the source are expected to build. We skip obvious
// test files, and cmd files, which should not be dependencies.
-func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) {
+func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings []Finding, facts []byte, err error) {
if len(config.Srcs) == 0 {
return nil, nil, nil
}
@@ -225,7 +224,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
// Aggregate all files by directory.
- packages := make(map[string]*packageConfig)
+ packages := make(map[string]*PackageConfig)
for _, file := range config.Srcs {
if !strings.HasPrefix(file, rootSrcPrefix) {
// Superflouous file.
@@ -243,7 +242,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
c, ok := packages[pkg]
if !ok {
- c = &packageConfig{
+ c = &PackageConfig{
ImportPath: pkg,
GOOS: config.GOOS,
GOARCH: config.GOARCH,
@@ -262,14 +261,18 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}
// Closure to check a single package.
- allFindings := make([]string, 0)
stdlibFacts := make(map[string][]byte)
+ stdlibErrs := make(map[string]error)
var checkOne func(pkg string) error // Recursive.
checkOne = func(pkg string) error {
// Is this already done?
if _, ok := stdlibFacts[pkg]; ok {
return nil
}
+ // Did this fail previously?
+ if _, ok := stdlibErrs[pkg]; ok {
+ return nil
+ }
// Lookup the configuration.
config, ok := packages[pkg]
@@ -283,6 +286,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
// If there's no binary for this package, it is likely
// not built with the distribution. That's fine, we can
// just skip analysis.
+ stdlibErrs[pkg] = err
return nil
}
@@ -295,10 +299,11 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
}()
// Run the analysis.
- findings, factData, err := checkPackage(config, ac, checkOne)
+ findings, factData, err := CheckPackage(config, analyzers, checkOne)
if err != nil {
// If we can't analyze a package from the standard library,
// then we skip it. It will simply not have any findings.
+ stdlibErrs[pkg] = err
return nil
}
stdlibFacts[pkg] = factData
@@ -312,7 +317,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
// to evaluate in the order provided here. We do ensure however, that
// all packages are evaluated.
for pkg := range packages {
- checkOne(pkg)
+ if err := checkOne(pkg); err != nil {
+ return nil, nil, err
+ }
}
// Sanity check.
@@ -326,11 +333,16 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
return nil, nil, fmt.Errorf("error saving stdlib facts: %w", err)
}
+ // Write out all errors.
+ for pkg, err := range stdlibErrs {
+ log.Printf("WARNING: error while processing %v: %v", pkg, err)
+ }
+
// Return all findings.
return allFindings, factData, nil
}
-// checkPackage runs all analyzers.
+// CheckPackage runs all given analyzers.
//
// The implementation was adapted from [1], which was in turn adpated from [2].
// This returns a list of matching analysis issues, or an error if the analysis
@@ -338,9 +350,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str
//
// [1] bazelbuid/rules_go/tools/builders/nogo_main.go
// [2] golang.org/x/tools/go/checker/internal/checker
-func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) {
+func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importCallback func(string) error) (findings []Finding, factData []byte, err error) {
imp := &importer{
- packageConfig: config,
+ PackageConfig: config,
fset: token.NewFileSet(),
cache: make(map[string]*types.Package),
callback: importCallback,
@@ -392,7 +404,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
// Register fact types and establish dependencies between analyzers.
// The visit closure will execute recursively, and populate results
// will all required analysis results.
- diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic)
results := make(map[*analysis.Analyzer]interface{})
var visit func(*analysis.Analyzer) error // For recursion.
visit = func(a *analysis.Analyzer) error {
@@ -407,27 +418,25 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
}
- // Prepare the matcher.
- m := ac[a]
- report := func(d analysis.Diagnostic) {
- if m.ShouldReport(d, imp.fset) {
- diagnostics[a] = append(diagnostics[a], d)
- }
- }
-
// Run the analysis.
factFilter := make(map[reflect.Type]bool)
for _, f := range a.FactTypes {
factFilter[reflect.TypeOf(f)] = true
}
p := &analysis.Pass{
- Analyzer: a,
- Fset: imp.fset,
- Files: syntax,
- Pkg: types,
- TypesInfo: typesInfo,
- ResultOf: results, // All results.
- Report: report,
+ Analyzer: a,
+ Fset: imp.fset,
+ Files: syntax,
+ Pkg: types,
+ TypesInfo: typesInfo,
+ ResultOf: results, // All results.
+ Report: func(d analysis.Diagnostic) {
+ findings = append(findings, Finding{
+ Category: AnalyzerName(a.Name),
+ Position: imp.fset.Position(d.Pos),
+ Message: d.Message,
+ })
+ },
ImportPackageFact: facts.ImportPackageFact,
ExportPackageFact: facts.ExportPackageFact,
ImportObjectFact: facts.ImportObjectFact,
@@ -450,7 +459,7 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
// Visit all analyzers recursively.
- for a, _ := range ac {
+ for _, a := range analyzers {
if imp.lastErr == ErrSkip {
continue // No local analysis.
}
@@ -459,114 +468,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo
}
}
- // Convert all diagnostics to strings.
- findings := make([]string, 0, len(diagnostics))
- for a, ds := range diagnostics {
- for _, d := range ds {
- // Include the anlyzer name for debugability and configuration.
- findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message))
- }
- }
-
// Return all findings.
- factData := facts.Encode()
- return findings, factData, nil
-}
-
-var (
- packageFile = flag.String("package", "", "package configuration file (in JSON format)")
- stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
- findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
- factsOutput = flag.String("facts", "", "output file for facts (optional)")
- escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
-)
-
-func loadConfig(file string, config interface{}) interface{} {
- // Load the configuration.
- f, err := os.Open(file)
- if err != nil {
- log.Fatalf("unable to open configuration %q: %v", file, err)
- }
- defer f.Close()
- dec := json.NewDecoder(f)
- dec.DisallowUnknownFields()
- if err := dec.Decode(config); err != nil {
- log.Fatalf("unable to decode configuration: %v", err)
- }
- return config
-}
-
-// Main is the entrypoint; it should be called directly from main.
-//
-// N.B. This package registers it's own flags.
-func Main() {
- // Parse all flags.
- flag.Parse()
-
- var (
- findings []string
- factData []byte
- err error
- )
-
- // Check the configuration.
- if *packageFile != "" && *stdlibFile != "" {
- log.Fatalf("unable to perform stdlib and package analysis; provide only one!")
- } else if *stdlibFile != "" {
- // Perform basic analysis.
- c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig)
- findings, factData, err = checkStdlib(c, analyzerConfig)
- } else if *packageFile != "" {
- // Perform basic analysis.
- c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig)
- findings, factData, err = checkPackage(c, analyzerConfig, nil)
- // Do we need to do escape analysis?
- if *escapesOutput != "" {
- escapes, _, err := checkPackage(c, escapesConfig, nil)
- if err != nil {
- log.Fatalf("error performing escape analysis: %v", err)
- }
- f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
- if err != nil {
- log.Fatalf("unable to open output %q: %v", *escapesOutput, err)
- }
- defer f.Close()
- for _, escape := range escapes {
- fmt.Fprintf(f, "%s\n", escape)
- }
- }
- } else {
- log.Fatalf("please provide at least one of package or stdlib!")
- }
-
- // Save facts.
- if *factsOutput != "" {
- if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil {
- log.Fatalf("error saving findings to %q: %v", *factsOutput, err)
- }
- }
-
- // Open the output file.
- var w io.Writer = os.Stdout
- if *findingsOutput != "" {
- f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
- if err != nil {
- log.Fatalf("unable to open output %q: %v", *findingsOutput, err)
- }
- defer f.Close()
- w = f
- }
-
- // Handle findings & errors.
- if err != nil {
- log.Fatalf("error checking package: %v", err)
- }
- if len(findings) == 0 {
- return
- }
-
- // Print findings.
- for _, finding := range findings {
- fmt.Fprintf(w, "%s\n", finding)
- }
+ return findings, facts.Encode(), nil
}
diff --git a/tools/nogo/register.go b/tools/nogo/register.go
deleted file mode 100644
index 34b173937..000000000
--- a/tools/nogo/register.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package nogo
-
-import (
- "encoding/gob"
- "log"
-
- "golang.org/x/tools/go/analysis"
-)
-
-// analyzers returns all configured analyzers.
-func analyzers() (all []*analysis.Analyzer) {
- for a, _ := range analyzerConfig {
- all = append(all, a)
- }
- for a, _ := range escapesConfig {
- all = append(all, a)
- }
- return all
-}
-
-func init() {
- // Validate basic configuration.
- if err := analysis.Validate(analyzers()); err != nil {
- log.Fatalf("unable to validate analyzer: %v", err)
- }
-
- // Register all fact types.
- //
- // N.B. This needs to be done recursively, because there may be
- // analyzers in the Requires list that do not appear explicitly above.
- registered := make(map[*analysis.Analyzer]struct{})
- var register func(*analysis.Analyzer)
- register = func(a *analysis.Analyzer) {
- if _, ok := registered[a]; ok {
- return
- }
-
- // Regsiter dependencies.
- for _, da := range a.Requires {
- register(da)
- }
-
- // Register local facts.
- for _, f := range a.FactTypes {
- gob.Register(f)
- }
-
- registered[a] = struct{}{} // Done.
- }
- for _, a := range analyzers() {
- register(a)
- }
-}
diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD
deleted file mode 100644
index 7ab340b51..000000000
--- a/tools/nogo/util/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "util",
- srcs = ["util.go"],
- visibility = ["//visibility:public"],
-)
diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go
deleted file mode 100644
index 919fec799..000000000
--- a/tools/nogo/util/util.go
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package util contains nogo-related utilities.
-package util
-
-import (
- "fmt"
- "io/ioutil"
- "regexp"
- "strconv"
- "strings"
-)
-
-// findingRegexp is used to parse findings.
-var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`)
-
-const (
- categoryIndex = 1
- fullPathAndLineIndex = 2
- fullPathIndex = 3
- lineIndex = 4
- messageIndex = 7
-)
-
-// Finding is a single finding.
-type Finding struct {
- Category string
- Path string
- Line int
- Message string
-}
-
-// ExtractFindingsFromFile loads findings from a file.
-func ExtractFindingsFromFile(filename string) ([]Finding, error) {
- content, err := ioutil.ReadFile(filename)
- if err != nil {
- return nil, err
- }
- return ExtractFindingsFromBytes(content)
-}
-
-// ExtractFindingsFromBytes loads findings from bytes.
-func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) {
- lines := strings.Split(string(content), "\n")
- for _, singleLine := range lines {
- // Skip blank lines.
- singleLine = strings.TrimSpace(singleLine)
- if singleLine == "" {
- continue
- }
- m := findingRegexp.FindStringSubmatch(singleLine)
- if m == nil {
- // We shouldn't see findings like this.
- return findings, fmt.Errorf("poorly formated line: %v", singleLine)
- }
- if m[fullPathAndLineIndex] == "-" {
- continue // No source file available.
- }
- // Cleanup the message.
- message := m[messageIndex]
- message = strings.Replace(message, " → ", "\n → ", -1)
- message = strings.Replace(message, " or ", "\n or ", -1)
- // Construct a new annotation.
- lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32)
- findings = append(findings, Finding{
- Category: m[categoryIndex],
- Path: m[fullPathIndex],
- Line: int(lineNumber),
- Message: message,
- })
- }
- return findings, nil
-}
diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD
index 7d9c9a3fb..dab954e25 100644
--- a/tools/parsers/BUILD
+++ b/tools/parsers/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,6 +7,7 @@ go_test(
size = "small",
srcs = ["go_parser_test.go"],
library = ":parsers",
+ nogo = False,
deps = [
"//tools/bigquery",
"@com_github_google_go_cmp//cmp:go_default_library",
@@ -19,9 +20,22 @@ go_library(
srcs = [
"go_parser.go",
],
+ nogo = False,
visibility = ["//:sandbox"],
deps = [
"//test/benchmarks/tools",
"//tools/bigquery",
],
)
+
+go_binary(
+ name = "parser",
+ testonly = 1,
+ srcs = ["parser_main.go"],
+ nogo = False,
+ deps = [
+ ":parsers",
+ "//runsc/flag",
+ "//tools/bigquery",
+ ],
+)
diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go
index 2cf74c883..df4875e6a 100644
--- a/tools/parsers/go_parser.go
+++ b/tools/parsers/go_parser.go
@@ -27,20 +27,21 @@ import (
"gvisor.dev/gvisor/tools/bigquery"
)
-// parseOutput expects golang benchmark output returns a Benchmark struct formatted for BigQuery.
-func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*bigquery.Benchmark, error) {
- var benchmarks []*bigquery.Benchmark
+// ParseOutput expects golang benchmark output and returns a struct formatted
+// for BigQuery.
+func ParseOutput(output string, name string, official bool) (*bigquery.Suite, error) {
+ suite := bigquery.NewSuite(name)
lines := strings.Split(output, "\n")
for _, line := range lines {
- bm, err := parseLine(line, metadata, official)
+ bm, err := parseLine(line, official)
if err != nil {
return nil, fmt.Errorf("failed to parse line '%s': %v", line, err)
}
if bm != nil {
- benchmarks = append(benchmarks, bm)
+ suite.Benchmarks = append(suite.Benchmarks, bm)
}
}
- return benchmarks, nil
+ return suite, nil
}
// parseLine handles parsing a benchmark line into a bigquery.Benchmark.
@@ -58,9 +59,8 @@ func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*
// {Name: ns/op, Unit: ns/op, Sample: 1397875880}
// {Name: requests_per_second, Unit: QPS, Sample: 140 }
// }
-// Metadata: metadata
//}
-func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigquery.Benchmark, error) {
+func parseLine(line string, official bool) (*bigquery.Benchmark, error) {
fields := strings.Fields(line)
// Check if this line is a Benchmark line. Otherwise ignore the line.
@@ -79,7 +79,6 @@ func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigque
}
bm := bigquery.NewBenchmark(name, iters, official)
- bm.Metadata = metadata
for _, p := range params {
bm.AddCondition(p.Name, p.Value)
}
diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go
index 36996b7c8..0aa1152a2 100644
--- a/tools/parsers/go_parser_test.go
+++ b/tools/parsers/go_parser_test.go
@@ -94,13 +94,11 @@ func TestParseLine(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- got, err := parseLine(tc.data, nil, false)
+ got, err := parseLine(tc.data, false)
if err != nil {
t.Fatalf("parseLine failed with: %v", err)
}
- tc.want.Timestamp = got.Timestamp
-
if !cmp.Equal(tc.want, got, nil) {
for _, c := range got.Condition {
t.Logf("Cond: %+v", c)
@@ -150,14 +148,14 @@ BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 46
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- bms, err := parseOutput(tc.data, nil, false)
+ suite, err := ParseOutput(tc.data, "", false)
if err != nil {
t.Fatalf("parseOutput failed: %v", err)
- } else if len(bms) != tc.numBenchmarks {
- t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(bms), bms)
+ } else if len(suite.Benchmarks) != tc.numBenchmarks {
+ t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(suite.Benchmarks), suite.Benchmarks)
}
- for _, bm := range bms {
+ for _, bm := range suite.Benchmarks {
if len(bm.Metric) != tc.numMetrics {
t.Fatalf("NumMetrics failed want: %d got: %d %+v", tc.numMetrics, len(bm.Metric), bm.Metric)
}
diff --git a/tools/parsers/parser_main.go b/tools/parsers/parser_main.go
new file mode 100644
index 000000000..6c6182464
--- /dev/null
+++ b/tools/parsers/parser_main.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary parser parses Benchmark data from golang benchmarks,
+// puts it into a Schema for BigQuery, and sends it to BigQuery.
+// parser will also initialize a table with the Benchmarks BigQuery schema.
+package main
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "gvisor.dev/gvisor/runsc/flag"
+ bq "gvisor.dev/gvisor/tools/bigquery"
+ "gvisor.dev/gvisor/tools/parsers"
+)
+
+const (
+ initString = "init"
+ initDescription = "initializes a new table with benchmarks schema"
+ parseString = "parse"
+ parseDescription = "parses given benchmarks file and sends it to BigQuery table."
+)
+
+var (
+ // The init command will create a new dataset/table in the given project and initialize
+ // the table with the schema in //tools/bigquery/bigquery.go. If the table/dataset exists
+ // or has been initialized, init has no effect and successfully returns.
+ initCmd = flag.NewFlagSet(initString, flag.ContinueOnError)
+ initProject = initCmd.String("project", "", "GCP project to send benchmarks.")
+ initDataset = initCmd.String("dataset", "", "dataset to send benchmarks data.")
+ initTable = initCmd.String("table", "", "table to send benchmarks data.")
+
+ // The parse command parses benchmark data in `file` and sends it to the
+ // requested table.
+ parseCmd = flag.NewFlagSet(parseString, flag.ContinueOnError)
+ file = parseCmd.String("file", "", "file to parse for benchmarks")
+ name = parseCmd.String("suite_name", "", "name of the benchmark suite")
+ clNumber = parseCmd.String("cl", "", "changelist number of this run")
+ gitCommit = parseCmd.String("git_commit", "", "git commit sha for this run")
+ parseProject = parseCmd.String("project", "", "GCP project to send benchmarks.")
+ parseDataset = parseCmd.String("dataset", "", "dataset to send benchmarks data.")
+ parseTable = parseCmd.String("table", "", "table to send benchmarks data.")
+ official = parseCmd.Bool("official", false, "mark input data as official.")
+)
+
+// initBenchmarks initializes a dataset/table in a BigQuery project.
+func initBenchmarks(ctx context.Context) error {
+ return bq.InitBigQuery(ctx, *initProject, *initDataset, *initTable, nil)
+}
+
+// parseBenchmarks parses the given file into the BigQuery schema,
+// adds some custom data for the commit, and sends the data to BigQuery.
+func parseBenchmarks(ctx context.Context) error {
+ data, err := ioutil.ReadFile(*file)
+ if err != nil {
+ return fmt.Errorf("failed to read file: %v", err)
+ }
+ suite, err := parsers.ParseOutput(string(data), *name, *official)
+ if err != nil {
+ return fmt.Errorf("failed parse data: %v", err)
+ }
+ extraConditions := []*bq.Condition{
+ {
+ Name: "change_list",
+ Value: *clNumber,
+ },
+ {
+ Name: "commit",
+ Value: *gitCommit,
+ },
+ }
+
+ suite.Conditions = append(suite.Conditions, extraConditions...)
+ return bq.SendBenchmarks(ctx, suite, *parseProject, *parseDataset, *parseTable, nil)
+}
+
+func main() {
+ ctx := context.Background()
+ switch {
+ // the "init" command
+ case len(os.Args) >= 2 && os.Args[1] == initString:
+ if err := initCmd.Parse(os.Args[2:]); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse flags: %v", err)
+ os.Exit(1)
+ }
+ if err := initBenchmarks(ctx); err != nil {
+ failure := "failed to initialize project: %s dataset: %s table: %s: %v"
+ fmt.Fprintf(os.Stderr, failure, *parseProject, *parseDataset, *parseTable, err)
+ os.Exit(1)
+ }
+ // the "parse" command.
+ case len(os.Args) >= 2 && os.Args[1] == parseString:
+ if err := parseCmd.Parse(os.Args[2:]); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse flags: %v", err)
+ os.Exit(1)
+ }
+ if err := parseBenchmarks(ctx); err != nil {
+ fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v", err)
+ os.Exit(1)
+ }
+ default:
+ printUsage()
+ }
+}
+
+// printUsage prints the top level usage string.
+func printUsage() {
+ usage := `Usage: parser <command> <flags> ...
+
+Available commands:
+ %s %s
+ %s %s
+`
+ fmt.Fprintf(os.Stderr, usage, initCmd.Name(), initDescription, parseCmd.Name(), parseDescription)
+}
diff --git a/tools/rules_go.patch b/tools/rules_go.patch
new file mode 100644
index 000000000..5e1e87084
--- /dev/null
+++ b/tools/rules_go.patch
@@ -0,0 +1,14 @@
+diff --git a/go/private/rules/test.bzl b/go/private/rules/test.bzl
+index 17516ad7..76b6c68c 100644
+--- a/go/private/rules/test.bzl
++++ b/go/private/rules/test.bzl
+@@ -121,9 +121,6 @@ def _go_test_impl(ctx):
+ )
+
+ test_gc_linkopts = gc_linkopts(ctx)
+- if not go.mode.debug:
+- # Disable symbol table and DWARF generation for test binaries.
+- test_gc_linkopts.extend(["-s", "-w"])
+
+ # Now compile the test binary itself
+ test_library = GoLibrary(
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index b0bab74b4..50378065e 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -43,7 +43,7 @@ fi
closest_commit() {
while read line; do
- if [[ "$line" =~ "commit " ]]; then
+ if [[ "$line" =~ ^"commit " ]]; then
current_commit="${line#commit }"
continue
elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then
@@ -57,7 +57,9 @@ closest_commit() {
# Is the passed identifier a sha commit?
if ! git show "${target_commit}" &> /dev/null; then
# Extract the commit given a piper ID.
- declare -r commit="$(git log | closest_commit "${target_commit}")"
+ commit="$(set +o pipefail; \
+ git log --first-parent | closest_commit "${target_commit}")"
+ declare -r commit
else
declare -r commit="${target_commit}"
fi
diff --git a/website/BUILD b/website/BUILD
index f3642b903..676c2b701 100644
--- a/website/BUILD
+++ b/website/BUILD
@@ -6,11 +6,16 @@ package(licenses = ["notice"])
docker_image(
name = "website",
- data = [":files"],
+ data = ":files",
statements = [
"EXPOSE 8080/tcp",
'ENTRYPOINT ["/server"]',
],
+ tags = [
+ "local",
+ "manual",
+ "nosandbox",
+ ],
)
# files is the full file system of the generated container.
diff --git a/website/_config.yml b/website/_config.yml
index 20fbb3d2d..51cb8e13c 100644
--- a/website/_config.yml
+++ b/website/_config.yml
@@ -37,3 +37,9 @@ authors:
fvoznika:
name: Fabricio Voznika
email: fvoznika@google.com
+ ianlewis:
+ name: Ian Lewis
+ email: ianlewis@google.com
+ mpratt:
+ name: Michael Pratt
+ email: mpratt@google.com
diff --git a/website/blog/2020-10-22-platform-portability.md b/website/blog/2020-10-22-platform-portability.md
new file mode 100644
index 000000000..4d82940f9
--- /dev/null
+++ b/website/blog/2020-10-22-platform-portability.md
@@ -0,0 +1,120 @@
+# Platform Portability
+
+Hardware virtualization is often seen as a requirement to provide an additional
+isolation layer for untrusted applications. However, hardware virtualization
+requires expensive bare-metal machines or cloud instances to run safely with
+good performance, increasing cost and complexity for Cloud users. gVisor,
+however, takes a more flexible approach.
+
+One of the pillars of gVisor's architecture is portability, allowing it to run
+anywhere that runs Linux. Modern Cloud-Native applications run in containers in
+many different places, from bare metal to virtual machines, and can't always
+rely on nested virtualization. It is important for gVisor to be able to support
+the environments where you run containers.
+
+gVisor achieves portability through an abstraction called a _Platform_.
+Platforms can have many implementations, and each implementation can cover
+different environments, making use of available software or hardware features.
+
+## Background
+
+Before we can understand how gVisor achieves portability using platforms, we
+should take a step back and understand how applications interact with their
+host.
+
+Container sandboxes can provide an isolation layer between the host and
+application by virtualizing one of the layers below it, including the hardware
+or operating system. Many sandboxes virtualize the hardware layer by running
+applications in virtual machines. gVisor takes a different approach by
+virtualizing the OS layer.
+
+When an application is run in a normal situation the host operating system loads
+the application into user memory and schedules it for execution. The operating
+system scheduler eventually schedules the application to a CPU and begins
+executing it. It then handles the application's requests, such as for memory and
+the lifecycle of the application. gVisor virtualizes these interactions, such as
+system calls, and context switching that happen between an application and OS.
+
+[System calls](https://en.wikipedia.org/wiki/System_call) allow applications to
+ask the OS to perform some task for it. System calls look like a normal function
+call in most programming languages though works a bit differently under the
+hood. When an application system call is encountered some special processing
+takes place to do a
+[context switch](https://en.wikipedia.org/wiki/Context_switch) into kernel mode
+and begin executing code in the kernel before returning a result to the
+application. Context switching may happen in other situations as well. For
+example, to respond to an interrupt.
+
+## The Platform Interface
+
+gVisor provides a sandbox which implements the Linux OS interface, intercepting
+OS interactions such as system calls and implements them in the sandbox kernel.
+
+It does this to limit interactions with the host, and protect the host from an
+untrusted application running in the sandbox. The Platform is the bottom layer
+of gVisor which provides the environment necessary for gVisor to control and
+manage applications. In general, the Platform must:
+
+1. Provide the ability to create and manage memory address spaces.
+2. Provide execution contexts for running applications in those memory address
+ spaces.
+3. Provide the ability to change execution context and return control to gVisor
+ at specific times (e.g. system call, page fault)
+
+This interface is conceptually simple, but very powerful. Since the Platform
+interface only requires these three capabilities, it gives gVisor enough control
+for it to act as the application's OS, while still allowing the use of very
+different isolation technologies under the hood. You can learn more about the
+Platform interface in the
+[Platform Guide](https://gvisor.dev/docs/architecture_guide/platforms/).
+
+## Implementations of the Platform Interface
+
+While gVisor can make use of technologies like hardware virtualization, it
+doesn't necessarily rely on any one technology to provide a similar level of
+isolation. The flexibility of the Platform interface allows for implementations
+that use technologies other than hardware virtualization. This allows gVisor to
+run in VMs without nested virtualization, for example. By providing an
+abstraction for the underlying platform, each implementation can make various
+tradeoffs regarding performance or hardware requirements.
+
+Currently gVisor provides two gVisor Platform implementations; the Ptrace
+Platform, and the KVM Platform, each using very different methods to implement
+the Platform interface.
+
+![gVisor Platforms](../../../../../docs/architecture_guide/platforms/platforms.png "Platforms")
+
+The Ptrace Platform uses
+[PTRACE\_SYSEMU](http://man7.org/linux/man-pages/man2/ptrace.2.html) to trap
+syscalls, and uses the host for memory mapping and context switching. This
+platform can run anywhere that ptrace is available, which includes most Linux
+systems, VMs or otherwise.
+
+The KVM Platform uses virtualization, but in an unconventional way. gVisor runs
+in a virtual machine but as both guest OS and VMM, and presents no virtualized
+hardware layer. This provides a simpler interface that can avoid hardware
+initialization for fast start up, while taking advantage of hardware
+virtualization support to improve memory isolation and performance of context
+switching.
+
+The flexibility of the Platform interface allows for a lot of room to improve
+the existing KVM and ptrace platforms, as well as the ability to utilize new
+methods for improving gVisor's performance or portability in future Platform
+implementations.
+
+## Portability
+
+Through the Platform interface, gVisor is able to support bare metal, virtual
+machines, and Cloud environments while still providing a highly secure sandbox
+for running untrusted applications. This is especially important for Cloud and
+Kubernetes users because it allows gVisor to run anywhere that Kubernetes can
+run and provide similar experiences in multi-region, hybrid, multi-platform
+environments.
+
+Give gVisor's open source platforms a try. Using a Platform is as easy as
+providing the `--platform` flag to `runsc`. See the documentation on
+[changing platforms](https://gvisor.dev/docs/user_guide/platforms/) for how to
+use different platforms with Docker. We would love to hear about your experience
+so come chat with us in our
+[Gitter channel](https://gitter.im/gvisor/community), or send us an
+[issue on Github](https://gvisor.dev/issue) if you run into any problems.
diff --git a/website/blog/BUILD b/website/blog/BUILD
index 865e403da..17beb721f 100644
--- a/website/blog/BUILD
+++ b/website/blog/BUILD
@@ -38,6 +38,17 @@ doc(
permalink = "/blog/2020/09/18/containing-a-real-vulnerability/",
)
+doc(
+ name = "platform_portability",
+ src = "2020-10-22-platform-portability.md",
+ authors = [
+ "ianlewis",
+ "mpratt",
+ ],
+ layout = "post",
+ permalink = "/blog/2020/10/22/platform-portability/",
+)
+
docs(
name = "posts",
deps = [
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
index c401b6abd..ac09550a9 100644
--- a/website/cmd/server/main.go
+++ b/website/cmd/server/main.go
@@ -29,6 +29,7 @@ var redirects = map[string]string{
// GitHub redirects.
"/change": "https://github.com/google/gvisor",
"/issue": "https://github.com/google/gvisor/issues",
+ "/issues": "https://github.com/google/gvisor/issues",
"/issue/new": "https://github.com/google/gvisor/issues/new",
"/pr": "https://github.com/google/gvisor/pulls",
@@ -44,14 +45,16 @@ var redirects = map[string]string{
"/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/",
// Redirect for old URLs.
- "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
- "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
- "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
- "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
- "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
- "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
- "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
- "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+ "/blog/2020/09/22/platform-portability": "/blog/2020/10/22/platform-portability/",
+ "/blog/2020/09/22/platform-portability/": "/blog/2020/10/22/platform-portability/",
// Deprecated, but links continue to work.
"/cl": "https://gvisor-review.googlesource.com",
@@ -60,6 +63,7 @@ var redirects = map[string]string{
var prefixHelpers = map[string]string{
"change": "https://github.com/google/gvisor/commit/%s",
"issue": "https://github.com/google/gvisor/issues/%s",
+ "issues": "https://github.com/google/gvisor/issues/%s",
"pr": "https://github.com/google/gvisor/pull/%s",
// Redirects to compatibility docs.