diff options
271 files changed, 11133 insertions, 4328 deletions
@@ -1,10 +1,17 @@ load("//tools:defs.bzl", "build_test", "gazelle", "go_path") +load("//tools/nogo:defs.bzl", "nogo_config") load("//website:defs.bzl", "doc") package(licenses = ["notice"]) exports_files(["LICENSE"]) +nogo_config( + name = "nogo_config", + srcs = ["nogo.yaml"], + visibility = ["//:sandbox"], +) + doc( name = "contributing", src = "CONTRIBUTING.md", @@ -260,7 +260,7 @@ website-build: load-jekyll ## Build the site image locally. .PHONY: website-build website-server: website-build ## Run a local server for development. - @docker run -i -p 8080:8080 gvisor.dev/images/website + @docker run -i -p 8080:8080 $(WEBSITE_IMAGE) .PHONY: website-server website-push: website-build ## Push a new image and update the service. diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md index b6a3186d8..a98fe5c4a 100644 --- a/g3doc/user_guide/containerd/quick_start.md +++ b/g3doc/user_guide/containerd/quick_start.md @@ -1,7 +1,7 @@ # Containerd Quick Start -This document describes how to install and configure `containerd-shim-runsc-v1` -using the containerd runtime handler support on `containerd` 1.2 or later. +This document describes how to use `containerd-shim-runsc-v1` with the +containerd runtime handler support on `containerd` 1.2 or later. > ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you > may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details. diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index abb9e8582..c3ced9d61 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -13,15 +13,19 @@ To download and install the latest release manually follow these steps: ( set -e URL=https://storage.googleapis.com/gvisor/releases/release/latest - wget ${URL}/runsc ${URL}/runsc.sha512 - sha512sum -c runsc.sha512 - rm -f runsc.sha512 - sudo mv runsc /usr/local/bin - sudo chmod a+rx /usr/local/bin/runsc + wget ${URL}/runsc ${URL}/runsc.sha512 \ + ${URL}/gvisor-containerd-shim ${URL}/gvisor-containerd-shim.sha512 \ + ${URL}/containerd-shim-runsc-v1 ${URL}/containerd-shim-runsc-v1.sha512 + sha512sum -c runsc.sha512 \ + -c gvisor-containerd-shim.sha512 \ + -c containerd-shim-runsc-v1.sha512 + rm -f *.sha512 + chmod a+rx runsc gvisor-containerd-shim containerd-shim-runsc-v1 + sudo mv runsc gvisor-containerd-shim containerd-shim-runsc-v1 /usr/local/bin ) ``` -To install gVisor with Docker, run the following commands: +To install gVisor as a Docker runtime, run the following commands: ```bash /usr/local/bin/runsc install @@ -165,5 +169,6 @@ You can use this link with the steps described in Note that `apt` installation of a specific point release is not supported. After installation, try out `runsc` by following the -[Docker Quick Start](./quick_start/docker.md) or +[Docker Quick Start](./quick_start/docker.md), +[Containerd QuickStart](./containerd/quick_start.md), or [OCI Quick Start](./quick_start/oci.md). diff --git a/images/defs.bzl b/images/defs.bzl index 61d7bbf73..c1f96e312 100644 --- a/images/defs.bzl +++ b/images/defs.bzl @@ -2,30 +2,33 @@ def _docker_image_impl(ctx): importer = ctx.actions.declare_file(ctx.label.name) + importer_content = [ "#!/bin/bash", "set -euo pipefail", + "source_file='%s'" % ctx.file.data.path, + "if [[ ! -f \"$source_file\" ]]; then", + " source_file='%s'" % ctx.file.data.short_path, + "fi", "exec docker import " + " ".join([ "-c '%s'" % attr for attr in ctx.attr.statements - ]) + " " + " ".join([ - "'%s'" % f.path - for f in ctx.files.data - ]) + " $1", + ]) + " \"$source_file\" $1", "", ] + ctx.actions.write(importer, "\n".join(importer_content), is_executable = True) return [DefaultInfo( - runfiles = ctx.runfiles(ctx.files.data), + runfiles = ctx.runfiles([ctx.file.data]), executable = importer, )] docker_image = rule( implementation = _docker_image_impl, - doc = "Tool to load a Docker image; takes a single parameter (image name).", + doc = "Tool to import a Docker image; takes a single parameter (image name).", attrs = { "statements": attr.string_list(doc = "Extra Dockerfile directives."), - "data": attr.label_list(doc = "All image data."), + "data": attr.label(doc = "Image filesystem tarball", allow_single_file = [".tgz", ".tar.gz"]), }, executable = True, ) diff --git a/nogo.yaml b/nogo.yaml new file mode 100644 index 000000000..07fd9aa4d --- /dev/null +++ b/nogo.yaml @@ -0,0 +1,495 @@ +groups: + # We define three basic groups: generated (all generated files), + # exteranl (all files outside the repository), and internal (all + # files within the local repository). We can't enforce many style + # checks on generated and external code, so enable those cases + # selectively for analyzers below. + - name: generated + regex: "^(bazel-genfiles|bazel-out|bazel-bin)/" + default: true + - name: external + regex: "^external/" + default: false + - name: internal + regex: ".*" + default: true +global: + generated: + suppress: + # Suppress the basic style checks for + # generated code, but keep the analysis + # that are required for quality & security. + - "should not use ALL_CAPS in Go names" + - "should not use underscores" + - "comment on exported" + - "methods on the same type should have the same receiver name" + - "at least one file in a package" + - "package comment should be of the form" + # Generated code may have dead code paths. + - "identical build constraints" + - "no value of type" + - "is never used" + # go_embed_data rules generate unicode literals. + - "string literal contains the Unicode format character" + - "string literal contains the Unicode control character" + # Some external code will generate protov1 + # implementations. These should be ignored. + - "proto.* is deprecated" + - "xxx_messageInfo_.*" + - "receiver name should be a reflection of its identity" + # Generated gRPC code is not compliant either. + - "error strings should not be capitalized" + - "grpc.Errorf is deprecated" + internal: + suppress: + # We use ALL_CAPS for system definitions, + # which are common enough in the code base + # that we shouldn't annotate exceptions. + # + # Same story for underscores. + - "should not use ALL_CAPS in Go names" + - "should not use underscores in Go names" + exclude: + # A variety of staticcheck and stylecheck + # rules apply here. These should be fixed + # and removed from here, and the global + # rules should be used sparingly. + - pkg/abi/linux/fuse.go:22 + - pkg/abi/linux/fuse.go:25 + - pkg/abi/linux/socket.go:113 + - pkg/abi/linux/tty.go:73 + - pkg/bpf/decoder.go:112 + - pkg/cpuid/cpuid_x86.go:675 + - pkg/eventchannel/event.go:193 + - pkg/eventchannel/event.go:27 + - pkg/eventchannel/event_test.go:22 + - pkg/eventchannel/rate.go:19 + - pkg/gohacks/gohacks_unsafe.go:33 + - pkg/log/json.go:30 + - pkg/log/log.go:359 + - pkg/merkletree/merkletree.go:230 + - pkg/merkletree/merkletree.go:243 + - pkg/merkletree/merkletree.go:249 + - pkg/merkletree/merkletree.go:266 + - pkg/merkletree/merkletree.go:355 + - pkg/merkletree/merkletree.go:369 + - pkg/metric/metric_test.go:20 + - pkg/p9/p9test/client_test.go:687 + - pkg/p9/transport_test.go:196 + - pkg/pool/pool.go:15 + - pkg/refs/refcounter.go:510 + - pkg/refs/refcounter_test.go:169 + - pkg/refs_vfs2/refs.go:16 + - pkg/safemem/block_unsafe.go:89 + - pkg/seccomp/seccomp.go:82 + - pkg/segment/test/set_functions.go:15 + - pkg/sentry/arch/signal.go:166 + - pkg/sentry/arch/signal.go:171 + - pkg/sentry/control/pprof.go:196 + - pkg/sentry/devices/memdev/full.go:58 + - pkg/sentry/devices/memdev/null.go:59 + - pkg/sentry/devices/memdev/random.go:68 + - pkg/sentry/devices/memdev/zero.go:86 + - pkg/sentry/fdimport/fdimport.go:15 + - pkg/sentry/fs/attr.go:257 + - pkg/sentry/fsbridge/fs.go:116 + - pkg/sentry/fsbridge/vfs.go:124 + - pkg/sentry/fsbridge/vfs.go:70 + - pkg/sentry/fs/copy_up.go:365 + - pkg/sentry/fs/copy_up_test.go:65 + - pkg/sentry/fs/dev/net_tun.go:161 + - pkg/sentry/fs/dev/net_tun.go:63 + - pkg/sentry/fs/dev/null.go:97 + - pkg/sentry/fs/dirent_cache.go:64 + - pkg/sentry/fs/file_overlay.go:327 + - pkg/sentry/fs/file_overlay.go:524 + - pkg/sentry/fs/filetest/filetest.go:55 + - pkg/sentry/fs/filetest/filetest.go:60 + - pkg/sentry/fs/fs.go:77 + - pkg/sentry/fs/fsutil/file.go:290 + - pkg/sentry/fs/fsutil/file.go:346 + - pkg/sentry/fs/fsutil/host_file_mapper.go:105 + - pkg/sentry/fs/fsutil/inode_cached.go:676 + - pkg/sentry/fs/fsutil/inode_cached.go:772 + - pkg/sentry/fs/gofer/attr.go:120 + - pkg/sentry/fs/gofer/fifo.go:33 + - pkg/sentry/fs/gofer/inode.go:410 + - pkg/sentry/fsimpl/devpts/devpts.go:110 + - pkg/sentry/fsimpl/devpts/devpts.go:246 + - pkg/sentry/fsimpl/devpts/devpts.go:50 + - pkg/sentry/fsimpl/devpts/master.go:110 + - pkg/sentry/fsimpl/devpts/master.go:55 + - pkg/sentry/fsimpl/devpts/replica.go:113 + - pkg/sentry/fsimpl/devpts/replica.go:57 + - pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54 + - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97 + - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92 + - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44 + - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91 + - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93 + - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66 + - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53 + - pkg/sentry/fsimpl/eventfd/eventfd.go:268 + - pkg/sentry/fsimpl/ext/directory.go:163 + - pkg/sentry/fsimpl/ext/directory.go:164 + - pkg/sentry/fsimpl/ext/extent_file.go:142 + - pkg/sentry/fsimpl/ext/extent_file.go:143 + - pkg/sentry/fsimpl/ext/ext.go:105 + - pkg/sentry/fsimpl/ext/filesystem.go:287 + - pkg/sentry/fsimpl/ext/regular_file.go:153 + - pkg/sentry/fsimpl/ext/symlink.go:113 + - pkg/sentry/fsimpl/fuse/connection_control.go:194 + - pkg/sentry/fsimpl/fuse/dev.go:387 + - pkg/sentry/fsimpl/fuse/dev_test.go:318 + - pkg/sentry/fsimpl/fuse/fusefs.go:102 + - pkg/sentry/fsimpl/fuse/read_write.go:129 + - pkg/sentry/fsimpl/fuse/request_response.go:71 + - pkg/sentry/fsimpl/gofer/directory.go:135 + - pkg/sentry/fsimpl/gofer/filesystem.go:679 + - pkg/sentry/fsimpl/gofer/gofer.go:1694 + - pkg/sentry/fsimpl/gofer/gofer.go:276 + - pkg/sentry/fsimpl/gofer/regular_file.go:81 + - pkg/sentry/fsimpl/gofer/special_file.go:141 + - pkg/sentry/fsimpl/host/host.go:184 + - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50 + - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90 + - pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273 + - pkg/sentry/fsimpl/kernfs/filesystem.go:247 + - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320 + - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497 + - pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52 + - pkg/sentry/fsimpl/overlay/directory.go:119 + - pkg/sentry/fsimpl/overlay/filesystem.go:527 + - pkg/sentry/fsimpl/overlay/non_directory.go:152 + - pkg/sentry/fsimpl/overlay/overlay.go:115 + - pkg/sentry/fsimpl/overlay/overlay.go:719 + - pkg/sentry/fsimpl/pipefs/pipefs.go:74 + - pkg/sentry/fsimpl/proc/filesystem.go:52 + - pkg/sentry/fsimpl/proc/filesystem.go:81 + - pkg/sentry/fsimpl/proc/subtasks.go:126 + - pkg/sentry/fsimpl/proc/subtasks.go:189 + - pkg/sentry/fsimpl/proc/task_fds.go:168 + - pkg/sentry/fsimpl/proc/task_fds.go:228 + - pkg/sentry/fsimpl/proc/task_fds.go:301 + - pkg/sentry/fsimpl/proc/task_fds.go:318 + - pkg/sentry/fsimpl/proc/task_fds.go:67 + - pkg/sentry/fsimpl/proc/task_files.go:112 + - pkg/sentry/fsimpl/proc/task_files.go:158 + - pkg/sentry/fsimpl/proc/task_files.go:259 + - pkg/sentry/fsimpl/proc/task_files.go:285 + - pkg/sentry/fsimpl/proc/task_files.go:305 + - pkg/sentry/fsimpl/proc/task_files.go:384 + - pkg/sentry/fsimpl/proc/task_files.go:403 + - pkg/sentry/fsimpl/proc/task_files.go:428 + - pkg/sentry/fsimpl/proc/task_files.go:691 + - pkg/sentry/fsimpl/proc/task_files.go:770 + - pkg/sentry/fsimpl/proc/task_files.go:797 + - pkg/sentry/fsimpl/proc/task_files.go:828 + - pkg/sentry/fsimpl/proc/task_files.go:879 + - pkg/sentry/fsimpl/proc/task_files.go:910 + - pkg/sentry/fsimpl/proc/task_files.go:961 + - pkg/sentry/fsimpl/proc/task.go:127 + - pkg/sentry/fsimpl/proc/task.go:193 + - pkg/sentry/fsimpl/proc/task_net.go:134 + - pkg/sentry/fsimpl/proc/task_net.go:475 + - pkg/sentry/fsimpl/proc/task_net.go:491 + - pkg/sentry/fsimpl/proc/task_net.go:508 + - pkg/sentry/fsimpl/proc/task_net.go:665 + - pkg/sentry/fsimpl/proc/task_net.go:715 + - pkg/sentry/fsimpl/proc/task_net.go:779 + - pkg/sentry/fsimpl/proc/tasks_files.go:113 + - pkg/sentry/fsimpl/proc/tasks_files.go:388 + - pkg/sentry/fsimpl/proc/tasks.go:232 + - pkg/sentry/fsimpl/proc/tasks_sys.go:145 + - pkg/sentry/fsimpl/proc/tasks_sys.go:181 + - pkg/sentry/fsimpl/proc/tasks_sys.go:239 + - pkg/sentry/fsimpl/proc/tasks_sys.go:291 + - pkg/sentry/fsimpl/proc/tasks_sys.go:375 + - pkg/sentry/fsimpl/signalfd/signalfd.go:124 + - pkg/sentry/fsimpl/signalfd/signalfd.go:15 + - pkg/sentry/fsimpl/signalfd/signalfd.go:126 + - pkg/sentry/fsimpl/sockfs/sockfs.go:36 + - pkg/sentry/fsimpl/sockfs/sockfs.go:79 + - pkg/sentry/fsimpl/sys/kcov.go:49 + - pkg/sentry/fsimpl/sys/kcov.go:99 + - pkg/sentry/fsimpl/sys/sys.go:118 + - pkg/sentry/fsimpl/sys/sys.go:56 + - pkg/sentry/fsimpl/testutil/testutil.go:257 + - pkg/sentry/fsimpl/testutil/testutil.go:260 + - pkg/sentry/fsimpl/timerfd/timerfd.go:87 + - pkg/sentry/fsimpl/tmpfs/directory.go:112 + - pkg/sentry/fsimpl/tmpfs/filesystem.go:195 + - pkg/sentry/fsimpl/tmpfs/regular_file.go:226 + - pkg/sentry/fsimpl/tmpfs/regular_file.go:346 + - pkg/sentry/fsimpl/tmpfs/tmpfs.go:103 + - pkg/sentry/fsimpl/tmpfs/tmpfs.go:733 + - pkg/sentry/fsimpl/verity/filesystem.go:490 + - pkg/sentry/fsimpl/verity/verity.go:156 + - pkg/sentry/fsimpl/verity/verity.go:629 + - pkg/sentry/fsimpl/verity/verity.go:672 + - pkg/sentry/fs/mount.go:162 + - pkg/sentry/fs/mount.go:256 + - pkg/sentry/fs/mount_overlay.go:144 + - pkg/sentry/fs/mounts.go:432 + - pkg/sentry/fs/proc/exec_args.go:104 + - pkg/sentry/fs/proc/exec_args.go:73 + - pkg/sentry/fs/proc/fds.go:269 + - pkg/sentry/fs/proc/loadavg.go:33 + - pkg/sentry/fs/proc/meminfo.go:39 + - pkg/sentry/fs/proc/mounts.go:193 + - pkg/sentry/fs/proc/mounts.go:84 + - pkg/sentry/fs/proc/net.go:125 + - pkg/sentry/fs/proc/proc.go:146 + - pkg/sentry/fs/proc/proc.go:204 + - pkg/sentry/fs/proc/seqfile/seqfile.go:210 + - pkg/sentry/fs/proc/sys.go:146 + - pkg/sentry/fs/proc/sys.go:43 + - pkg/sentry/fs/proc/sys_net.go:113 + - pkg/sentry/fs/proc/sys_net.go:205 + - pkg/sentry/fs/proc/sys_net.go:233 + - pkg/sentry/fs/proc/sys_net.go:307 + - pkg/sentry/fs/proc/sys_net.go:335 + - pkg/sentry/fs/proc/sys_net.go:446 + - pkg/sentry/fs/proc/sys_net.go:456 + - pkg/sentry/fs/proc/sys_net.go:89 + - pkg/sentry/fs/proc/task.go:170 + - pkg/sentry/fs/proc/task.go:322 + - pkg/sentry/fs/proc/task.go:427 + - pkg/sentry/fs/proc/task.go:467 + - pkg/sentry/fs/proc/task.go:500 + - pkg/sentry/fs/proc/task.go:784 + - pkg/sentry/fs/proc/task.go:839 + - pkg/sentry/fs/proc/task.go:920 + - pkg/sentry/fs/proc/uid_gid_map.go:108 + - pkg/sentry/fs/proc/uid_gid_map.go:79 + - pkg/sentry/fs/proc/uptime.go:75 + - pkg/sentry/fs/ramfs/dir.go:447 + - pkg/sentry/fs/tmpfs/inode_file.go:436 + - pkg/sentry/fs/tmpfs/inode_file.go:537 + - pkg/sentry/fs/tty/dir.go:313 + - pkg/sentry/fs/tty/master.go:131 + - pkg/sentry/fs/tty/master.go:91 + - pkg/sentry/fs/tty/replica.go:116 + - pkg/sentry/fs/tty/replica.go:88 + - pkg/sentry/kernel/auth/id_map.go:269 + - pkg/sentry/kernel/fasync/fasync.go:67 + - pkg/sentry/kernel/kcov.go:209 + - pkg/sentry/kernel/kcov.go:223 + - pkg/sentry/kernel/kernel.go:343 + - pkg/sentry/kernel/kernel.go:368 + - pkg/sentry/kernel/pipe/node_test.go:112 + - pkg/sentry/kernel/pipe/node_test.go:119 + - pkg/sentry/kernel/pipe/node_test.go:130 + - pkg/sentry/kernel/pipe/node_test.go:137 + - pkg/sentry/kernel/pipe/node_test.go:149 + - pkg/sentry/kernel/pipe/node_test.go:150 + - pkg/sentry/kernel/pipe/node_test.go:158 + - pkg/sentry/kernel/pipe/node_test.go:174 + - pkg/sentry/kernel/pipe/node_test.go:180 + - pkg/sentry/kernel/pipe/node_test.go:193 + - pkg/sentry/kernel/pipe/node_test.go:202 + - pkg/sentry/kernel/pipe/node_test.go:205 + - pkg/sentry/kernel/pipe/node_test.go:216 + - pkg/sentry/kernel/pipe/node_test.go:219 + - pkg/sentry/kernel/pipe/node_test.go:271 + - pkg/sentry/kernel/pipe/node_test.go:290 + - pkg/sentry/kernel/pipe/pipe_test.go:93 + - pkg/sentry/kernel/pipe/reader_writer.go:65 + - pkg/sentry/kernel/posixtimer.go:157 + - pkg/sentry/kernel/ptrace.go:218 + - pkg/sentry/kernel/semaphore/semaphore.go:323 + - pkg/sentry/kernel/sessions.go:123 + - pkg/sentry/kernel/sessions.go:508 + - pkg/sentry/kernel/signal_handlers.go:57 + - pkg/sentry/kernel/task_context.go:72 + - pkg/sentry/kernel/task_exit.go:67 + - pkg/sentry/kernel/task_sched.go:255 + - pkg/sentry/kernel/task_sched.go:280 + - pkg/sentry/kernel/task_sched.go:323 + - pkg/sentry/kernel/task_stop.go:192 + - pkg/sentry/kernel/thread_group.go:530 + - pkg/sentry/kernel/timekeeper.go:316 + - pkg/sentry/kernel/vdso.go:106 + - pkg/sentry/kernel/vdso.go:118 + - pkg/sentry/memmap/memmap.go:103 + - pkg/sentry/memmap/memmap.go:163 + - pkg/sentry/mm/address_space.go:42 + - pkg/sentry/mm/address_space.go:42 + - pkg/sentry/mm/aio_context.go:208 + - pkg/sentry/mm/aio_context.go:288 + - pkg/sentry/mm/pma.go:683 + - pkg/sentry/mm/special_mappable.go:80 + - pkg/sentry/syscalls/linux/sys_sem.go:62 + - pkg/sentry/syscalls/linux/sys_time.go:189 + - pkg/sentry/usage/cpu.go:42 + - pkg/sentry/vfs/anonfs.go:302 + - pkg/sentry/vfs/anonfs.go:99 + - pkg/sentry/vfs/dentry.go:214 + - pkg/sentry/vfs/epoll.go:168 + - pkg/sentry/vfs/epoll.go:314 + - pkg/sentry/vfs/file_description.go:549 + - pkg/sentry/vfs/file_description_impl_util.go:304 + - pkg/sentry/vfs/file_description_impl_util.go:412 + - pkg/sentry/vfs/filesystem.go:76 + - pkg/sentry/vfs/lock.go:15 + - pkg/sentry/vfs/lock.go:47 + - pkg/sentry/vfs/memxattr/xattr.go:37 + - pkg/sentry/vfs/mount.go:510 + - pkg/sentry/vfs/mount.go:667 + - pkg/sentry/vfs/mount_test.go:106 + - pkg/sentry/vfs/mount_test.go:160 + - pkg/sentry/vfs/mount_test.go:215 + - pkg/sentry/vfs/mount_unsafe.go:153 + - pkg/sentry/vfs/resolving_path.go:228 + - pkg/sentry/vfs/vfs.go:897 + - pkg/shim/runsc/runsc.go:16 + - pkg/shim/runsc/utils.go:16 + - pkg/shim/v1/proc/deleted_state.go:16 + - pkg/shim/v1/proc/exec.go:16 + - pkg/shim/v1/proc/exec_state.go:16 + - pkg/shim/v1/proc/init.go:16 + - pkg/shim/v1/proc/init_state.go:16 + - pkg/shim/v1/proc/io.go:16 + - pkg/shim/v1/proc/process.go:16 + - pkg/shim/v1/proc/types.go:16 + - pkg/shim/v1/proc/utils.go:16 + - pkg/shim/v1/shim/api.go:16 + - pkg/shim/v1/shim/platform.go:16 + - pkg/shim/v1/shim/service.go:16 + - pkg/shim/v1/utils/annotations.go:15 + - pkg/shim/v1/utils/utils.go:15 + - pkg/shim/v1/utils/volumes.go:15 + - pkg/shim/v2/api.go:16 + - pkg/shim/v2/epoll.go:18 + - pkg/shim/v2/options/options.go:15 + - pkg/shim/v2/options/options.go:24 + - pkg/shim/v2/options/options.go:26 + - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16 + - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all. + - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22 + - pkg/shim/v2/service.go:15 + - pkg/shim/v2/service_linux.go:18 + - pkg/state/tests/integer_test.go:23 + - pkg/state/tests/integer_test.go:28 + - pkg/sync/rwmutex_test.go:105 + - pkg/syserr/host_linux.go:35 + - pkg/unet/unet_test.go:634 + - pkg/unet/unet_test.go:662 + - pkg/unet/unet_test.go:703 + - pkg/unet/unet_test.go:98 + - pkg/usermem/addr.go:34 + - pkg/usermem/usermem.go:171 + - pkg/usermem/usermem.go:170 + - runsc/boot/compat.go:22 + - runsc/boot/compat.go:56 + - runsc/boot/loader.go:1115 + - runsc/boot/loader.go:1120 + - runsc/cmd/checkpoint.go:151 + - runsc/config/flags.go:32 + - runsc/container/container.go:641 + - runsc/container/container.go:988 + - runsc/specutils/specutils.go:172 + - runsc/specutils/specutils.go:428 + - runsc/specutils/specutils.go:436 + - runsc/specutils/specutils.go:442 + - runsc/specutils/specutils.go:447 + - runsc/specutils/specutils.go:454 + - test/cmd/test_app/fds.go:171 + - test/iptables/filter_output.go:251 + - test/packetimpact/testbench/connections.go:77 + - tools/bigquery/bigquery.go:106 + - tools/checkescape/test1/test1.go:108 + - tools/checkescape/test1/test1.go:122 + - tools/checkescape/test1/test1.go:137 + - tools/checkescape/test1/test1.go:151 + - tools/checkescape/test1/test1.go:170 + - tools/checkescape/test1/test1.go:39 + - tools/checkescape/test1/test1.go:45 + - tools/checkescape/test1/test1.go:50 + - tools/checkescape/test1/test1.go:64 + - tools/checkescape/test1/test1.go:80 + - tools/checkescape/test1/test1.go:94 + - tools/go_generics/imports.go:51 + - tools/go_generics/imports.go:75 + - tools/go_marshal/gomarshal/generator.go:177 + - tools/go_marshal/gomarshal/generator.go:81 + - tools/go_marshal/gomarshal/generator.go:85 + - tools/go_marshal/test/escape/escape.go:15 + - tools/go_marshal/test/test.go:164 +analyzers: + asmdecl: + external: # Enabled. + assign: + external: + exclude: + - gazelle/walk/walk.go + atomic: + external: # Enabled. + bools: + external: # Enabled. + buildtag: + external: # Enabled. + cgocall: + external: # Enabled. + shadow: # Disable for now. + generated: + exclude: [".*"] + internal: + exclude: [".*"] + composites: # Disable for now. + generated: + exclude: [".*"] + internal: + exclude: [".*"] + errorsas: + external: # Enabled. + httpresponse: + external: # Enabled. + loopclosure: + external: # Enabled. + nilfunc: + external: # Enabled. + nilness: + internal: + exclude: + - pkg/sentry/platform/kvm/kvm_test.go # Intentional. + - tools/bigquery/bigquery.go # False positive. + printf: + external: # Enabled. + shift: + external: # Enabled. + stringintconv: + external: + exclude: + - ".*protobuf/.*.go" # Bad conversions. + - ".*flate/huffman_bit_writer.go" # Bad conversion. + # Runtime internal violations. + - ".*reflect/value.go" + - ".*encoding/xml/xml.go" + - ".*runtime/pprof/internal/profile/proto.go" + - ".*fmt/scan.go" + - ".*go/types/conversions.go" + - ".*golang.org/x/net/dns/dnsmessage/message.go" + tests: + external: # Enabled. + unmarshal: + external: # Enabled. + unreachable: + external: # Enabled. + unsafeptr: + internal: + exclude: + - ".*_test.go" # Exclude tests. + - "pkg/flipcall/.*_unsafe.go" # Special case. + - pkg/gohacks/gohacks_unsafe.go # Special case. + - pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go # Special case. + - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case. + - pkg/sentry/platform/kvm/machine_unsafe.go # Special case. + - pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go # Special case. + - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case. + - pkg/sentry/vfs/mount_unsafe.go # Special case. + - pkg/state/decode_unsafe.go # Special case. + unusedresult: + external: # Enabled. + checkescape: + external: # Enabled. diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 4a26e28de..a0654df2f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -55,6 +55,8 @@ go_library( "sched.go", "seccomp.go", "sem.go", + "sem_amd64.go", + "sem_arm64.go", "shm.go", "signal.go", "signalfd.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 7df02dd6d..006b5a525 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,6 +121,9 @@ const ( // Constants from uapi/linux/fsverity.h. const ( + FS_VERITY_HASH_ALG_SHA256 = 1 + FS_VERITY_HASH_ALG_SHA512 = 2 + FS_IOC_ENABLE_VERITY = 1082156677 FS_IOC_MEASURE_VERITY = 3221513862 ) diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 487a626cc..1b2f76c0b 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -34,18 +34,6 @@ const ( const SEM_UNDO = 0x1000 -// SemidDS is equivalent to struct semid64_ds. -// -// +marshal -type SemidDS struct { - SemPerm IPCPerm - SemOTime TimeT - SemCTime TimeT - SemNSems uint64 - unused3 uint64 - unused4 uint64 -} - // Sembuf is equivalent to struct sembuf. // // +marshal slice:SembufSlice diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go new file mode 100644 index 000000000..ab980cb4f --- /dev/null +++ b/pkg/abi/linux/sem_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: arch/x86/include/uapi/asm/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + unused1 uint64 + SemCTime TimeT + unused2 uint64 + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go new file mode 100644 index 000000000..521468fb1 --- /dev/null +++ b/pkg/abi/linux/sem_arm64.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: include/uapi/asm-generic/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + SemCTime TimeT + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/context/context.go b/pkg/context/context.go index 2613bc752..f3031fc60 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()} func Background() Context { return bgContext } + +// WithValue returns a copy of parent in which the value associated with key is +// val. +func WithValue(parent Context, key, val interface{}) Context { + return &withValue{ + Context: parent, + key: key, + val: val, + } +} + +type withValue struct { + Context + key interface{} + val interface{} +} + +// Value implements Context.Value. +func (ctx *withValue) Value(key interface{}) interface{} { + if key == ctx.key { + return ctx.val + } + return ctx.Context.Value(key) +} diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index a8fcb2e19..501a9ef21 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -6,12 +6,18 @@ go_library( name = "merkletree", srcs = ["merkletree.go"], visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) go_test( name = "merkletree_test", srcs = ["merkletree_test.go"], library = ":merkletree", - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d8227b8bd..18457d287 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -18,21 +18,32 @@ package merkletree import ( "bytes" "crypto/sha256" + "crypto/sha512" "fmt" "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) const ( // sha256DigestSize specifies the digest size of a SHA256 hash. sha256DigestSize = 32 + // sha512DigestSize specifies the digest size of a SHA512 hash. + sha512DigestSize = 64 ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). -func DigestSize() int { - return sha256DigestSize +// TODO(b/156980949): Allow config SHA384. +func DigestSize(hashAlgorithm int) int { + switch hashAlgorithm { + case linux.FS_VERITY_HASH_ALG_SHA256: + return sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + return sha512DigestSize + default: + return -1 + } } // Layout defines the scale of a Merkle tree. @@ -51,11 +62,19 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { +func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ blockSize: usermem.PageSize, - // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, + } + + // TODO(b/156980949): Allow config SHA384. + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + layout.digestSize = sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + layout.digestSize = sha512DigestSize + default: + return Layout{}, fmt.Errorf("unexpected hash algorithms") } // treeStart is the offset (in bytes) of the first level of the tree in @@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { } layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) - return layout + return layout, nil } // hashesPerBlock() returns the number of digests in each block. For example, @@ -139,12 +158,33 @@ func (d *VerityDescriptor) String() string { } // verify generates a hash from d, and compares it with expected. -func (d *VerityDescriptor) verify(expected []byte) error { - h := sha256.Sum256([]byte(d.String())) +func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { + h, err := hashData([]byte(d.String()), hashAlgorithms) + if err != nil { + return err + } if !bytes.Equal(h[:], expected) { return fmt.Errorf("unexpected root hash") } return nil + +} + +// hashData hashes data and returns the result hash based on the hash +// algorithms. +func hashData(data []byte, hashAlgorithms int) ([]byte, error) { + var digest []byte + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + digestArray := sha256.Sum256(data) + digest = digestArray[:] + case linux.FS_VERITY_HASH_ALG_SHA512: + digestArray := sha512.Sum512(data) + digest = digestArray[:] + default: + return nil, fmt.Errorf("unexpected hash algorithms") + } + return digest, nil } // GenerateParams contains the parameters used to generate a Merkle tree. @@ -161,6 +201,8 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // TreeReader is a reader for the Merkle tree. TreeReader io.ReaderAt // TreeWriter is a writer for the Merkle tree. @@ -176,7 +218,10 @@ type GenerateParams struct { // Generate returns a hash of a VerityDescriptor, which contains the file // metadata and the hash from file content. func Generate(params *GenerateParams) ([]byte, error) { - layout := InitLayout(params.Size, params.DataAndTreeInSameFile) + layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return nil, err + } numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize @@ -218,10 +263,13 @@ func Generate(params *GenerateParams) ([]byte, error) { return nil, err } // Hash the bytes in buf. - digest := sha256.Sum256(buf) + digest, err := hashData(buf, params.HashAlgorithms) + if err != nil { + return nil, err + } if level == layout.rootLevel() { - root = digest[:] + root = digest } // Write the generated hash to the end of the tree file. @@ -246,8 +294,7 @@ func Generate(params *GenerateParams) ([]byte, error) { GID: params.GID, RootHash: root, } - ret := sha256.Sum256([]byte(descriptor.String())) - return ret[:], nil + return hashData([]byte(descriptor.String()), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -269,6 +316,8 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // ReadOffset is the offset of the data range to be verified. ReadOffset int64 // ReadSize is the size of the data range to be verified. @@ -298,7 +347,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { GID: params.GID, RootHash: root, } - return descriptor.verify(params.Expected) + return descriptor.verify(params.Expected, params.HashAlgorithms) } // Verify verifies the content read from data with offset. The content is @@ -313,7 +362,10 @@ func Verify(params *VerifyParams) (int64, error) { if params.ReadSize < 0 { return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) } - layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return 0, err + } if params.ReadSize == 0 { return 0, verifyMetadata(params, &layout) } @@ -354,7 +406,7 @@ func Verify(params *VerifyParams) (int64, error) { UID: params.UID, GID: params.GID, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil { + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err } @@ -395,7 +447,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -406,8 +458,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, for level := 0; level < layout.numLevels(); level++ { // Calculate hash. if level == 0 { - digestArray := sha256.Sum256(dataBlock) - digest = digestArray[:] + h, err := hashData(dataBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } else { // Read a block in previous level that contains the // hash we just generated, and generate a next level @@ -415,8 +470,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil { return err } - digestArray := sha256.Sum256(treeBlock) - digest = digestArray[:] + h, err := hashData(treeBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } // Read the digest for the current block and store in @@ -434,5 +492,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. descriptor.RootHash = digest - return descriptor.verify(expected) + return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index e1350ebda..0782ca3e7 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -22,54 +22,114 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) func TestLayout(t *testing.T) { testCases := []struct { dataSize int64 + hashAlgorithms int dataAndTreeInSameFile bool + expectedDigestSize int64 expectedLevelOffset []int64 }{ { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0}, }, { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, + expectedLevelOffset: []int64{usermem.PageSize}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, expectedLevelOffset: []int64{usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + }, + { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + }, + { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) + if err != nil { + t.Fatalf("Failed to InitLayout: %v", err) + } if l.blockSize != int64(usermem.PageSize) { t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if l.digestSize != sha256DigestSize { + if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { @@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - expectedHash []byte + data []byte + hashAlgorithms int + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + }, + { + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{14, 27, 126, 158, 9, 94, 163, 51, 243, 162, 82, 167, 183, 127, 93, 121, 221, 23, 184, 59, 104, 166, 111, 49, 161, 195, 229, 111, 121, 201, 233, 68, 10, 154, 78, 142, 154, 236, 170, 156, 110, 167, 15, 144, 155, 97, 241, 235, 202, 233, 246, 217, 138, 88, 152, 179, 238, 46, 247, 185, 125, 20, 101, 201}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{55, 204, 240, 1, 224, 252, 58, 131, 251, 174, 45, 140, 107, 57, 118, 11, 18, 236, 203, 204, 19, 59, 27, 196, 3, 78, 21, 7, 22, 98, 197, 128, 17, 128, 90, 122, 54, 83, 253, 108, 156, 67, 59, 229, 236, 241, 69, 88, 99, 44, 127, 109, 204, 183, 150, 232, 187, 57, 228, 137, 209, 235, 241, 172}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, }, { - data: []byte{'a'}, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{207, 233, 114, 94, 113, 212, 243, 160, 59, 232, 226, 77, 28, 81, 176, 61, 211, 213, 222, 190, 148, 196, 90, 166, 237, 56, 113, 148, 230, 154, 23, 105, 14, 97, 144, 211, 12, 122, 226, 207, 167, 203, 136, 193, 38, 249, 227, 187, 92, 238, 101, 97, 170, 255, 246, 209, 246, 98, 241, 150, 175, 253, 173, 206}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + }, + { + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{110, 103, 29, 250, 27, 211, 235, 119, 112, 65, 49, 156, 6, 92, 66, 105, 133, 1, 187, 172, 169, 13, 186, 34, 105, 72, 252, 131, 12, 159, 91, 188, 79, 184, 240, 227, 40, 164, 72, 193, 65, 31, 227, 153, 191, 6, 117, 42, 82, 122, 33, 255, 92, 215, 215, 249, 2, 131, 170, 134, 39, 192, 222, 33}, }, } @@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) { Mode: defaultMode, UID: defaultUID, GID: defaultGID, + HashAlgorithms: tc.hashAlgorithms, TreeReader: &tree, TreeWriter: &tree, DataAndTreeInSameFile: dataAndTreeInSameFile, @@ -348,77 +434,81 @@ func TestVerify(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: tc.dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: tc.verifyStart, + ReadSize: tc.verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + if tc.modifyName { + verifyParams.Name = defaultName + "abc" } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + if tc.modifyMode { + verifyParams.Mode = defaultMode + 1 } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") + if tc.modifyUID { + verifyParams.UID = defaultUID + 1 } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") + if tc.modifyGID { + verifyParams.GID = defaultGID + 1 + } + if tc.shouldSucceed { + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&verifyParams); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } } } } @@ -435,87 +525,91 @@ func TestVerifyRandom(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + var buf bytes.Buffer + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: start, + ReadSize: size, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - // Verify that modified metadata should fail verification. - buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verify succeeded for modified metadata, expect failure") - } + // Verify that modified metadata should fail verification. + buf.Reset() + verifyParams.Name = defaultName + "abc" + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verify succeeded for modified metadata, expect failure") + } - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + verifyParams.File = bytes.NewReader(data) + verifyParams.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verification succeeded for modified data, expect failure") + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verification succeeded for modified data, expect failure") + } } } } diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD index 577b827a5..245e33d2d 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -19,8 +19,16 @@ go_template( ) go_library( - name = "refs_vfs2", - srcs = ["refs.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/context"], + name = "refsvfs2", + srcs = [ + "refs.go", + "refs_map.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/context", + "//pkg/log", + "//pkg/refs", + "//pkg/sync", + ], ) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go index 99a074e96..ef8beb659 100644 --- a/pkg/refs_vfs2/refs.go +++ b/pkg/refsvfs2/refs.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_vfs2 defines an interface for a reference-counted object. -package refs_vfs2 +// Package refsvfs2 defines an interface for a reference-counted object. +package refsvfs2 import ( "gvisor.dev/gvisor/pkg/context" diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go new file mode 100644 index 000000000..be75b0cc2 --- /dev/null +++ b/pkg/refsvfs2/refs_map.go @@ -0,0 +1,97 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package refsvfs2 + +import ( + "fmt" + "strings" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" +) + +// TODO(gvisor.dev/issue/1193): re-enable once kernfs refs are fixed. +var ignored []string = []string{"kernfs.", "proc.", "sys.", "devpts.", "fuse."} + +var ( + // liveObjects is a global map of reference-counted objects. Objects are + // inserted when leak check is enabled, and they are removed when they are + // destroyed. It is protected by liveObjectsMu. + liveObjects map[CheckedObject]struct{} + liveObjectsMu sync.Mutex +) + +// CheckedObject represents a reference-counted object with an informative +// leak detection message. +type CheckedObject interface { + // LeakMessage supplies a warning to be printed upon leak detection. + LeakMessage() string +} + +func init() { + liveObjects = make(map[CheckedObject]struct{}) +} + +// LeakCheckEnabled returns whether leak checking is enabled. The following +// functions should only be called if it returns true. +func LeakCheckEnabled() bool { + return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking +} + +// Register adds obj to the live object map. +func Register(obj CheckedObject, typ string) { + for _, str := range ignored { + if strings.Contains(typ, str) { + return + } + } + liveObjectsMu.Lock() + if _, ok := liveObjects[obj]; ok { + panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj)) + } + liveObjects[obj] = struct{}{} + liveObjectsMu.Unlock() +} + +// Unregister removes obj from the live object map. +func Unregister(obj CheckedObject, typ string) { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + if _, ok := liveObjects[obj]; !ok { + for _, str := range ignored { + if strings.Contains(typ, str) { + return + } + } + panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) + } + delete(liveObjects, obj) +} + +// DoLeakCheck iterates through the live object map and logs a message for each +// object. It is called once no reference-counted objects should be reachable +// anymore, at which point anything left in the map is considered a leak. +func DoLeakCheck() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + log.Warningf("Leak checking detected %d leaked objects:", leaked) + for obj := range liveObjects { + log.Warningf(obj.LeakMessage()) + } + } +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index d9b552896..ec295ef5b 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -21,11 +21,9 @@ package refs_template import ( "fmt" - "runtime" "sync/atomic" - "gvisor.dev/gvisor/pkg/log" - refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" ) // T is the type of the reference counted object. It is only used to customize @@ -42,11 +40,6 @@ var ownerType *T // Note that the number of references is actually refCount + 1 so that a default // zero-value Refs object contains one reference. // -// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in -// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. -// This will allow us to add stack trace information to the leak messages -// without growing the size of Refs. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -59,24 +52,16 @@ type Refs struct { refCount int64 } -func (r *Refs) finalize() { - var note string - switch refs_vfs1.GetLeakMode() { - case refs_vfs1.NoLeakChecking: - return - case refs_vfs1.UninitializedLeakChecking: - note = "(Leak checker uninitialized): " - } - if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) +// EnableLeakCheck enables reference leak checking on r. +func (r *Refs) EnableLeakCheck() { + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(r, fmt.Sprintf("%T", ownerType)) } } -// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. -func (r *Refs) EnableLeakCheck() { - if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*Refs).finalize) - } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *Refs) LeakMessage() string { + return fmt.Sprintf("%T %p: reference count of %d instead of 0", ownerType, r, r.ReadRefs()) } // ReadRefs returns the current number of references. The returned count is @@ -91,7 +76,7 @@ func (r *Refs) ReadRefs() int64 { //go:nosplit func (r *Refs) IncRef() { if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + panic(fmt.Sprintf("Incrementing non-positive count %p on %T", r, ownerType)) } } @@ -134,9 +119,18 @@ func (r *Refs) DecRef(destroy func()) { panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) case v == -1: + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(r, fmt.Sprintf("%T", ownerType)) + } // Call the destructor. if destroy != nil { destroy() } } } + +func (r *Refs) afterLoad() { + if refsvfs2.LeakCheckEnabled() && r.ReadRefs() > 0 { + r.EnableLeakCheck() + } +} diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index 41feeffe3..d800f2c85 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { s.Kernel.Kill(kernel.ExitStatus{}) }, } - return saveOpts.Save(s.Kernel, s.Watchdog) + return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog) } diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..ff5d49fbd 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -39,6 +39,8 @@ const ( ) // tunDevice implements vfs.Device for /dev/net/tun. +// +// +stateify savable type tunDevice struct{} // Open implements vfs.Device.Open. @@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt } // tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +// +// +stateify savable type tunFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 1390a9a7f..4468f5dd2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() { f.mappings = make(map[uint64]mapping) } +// IsInited returns true if f.Init() has been called. This is used when +// restoring a checkpoint that contains a HostFileMapper that may or may not +// have been initialized. +func (f *HostFileMapper) IsInited() bool { + return f.refs != nil +} + // NewHostFileMapper returns an initialized HostFileMapper allocated on the // heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 22d658acf..450044c9c 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "gid_map": newGIDMap(t, msrc), "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), + "mem": newMem(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "net": newNetDir(t, msrc), @@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(t, d, msrc, fs.SpecialDirectory, t) } +// memData implements fs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memData struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// memDataFile implements fs.FileOperations for /proc/[pid]/mem. +// +// +stateify savable +type memDataFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + inode := &memData{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, inode, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (m *memData) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, m.t, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(m.t); err != nil { + return nil, err + } + // Enable random access reads + flags.Pread = true + return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + mm, err := getTaskMM(m.t) + if err != nil { + return 0, nil + } + defer mm.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + // mapsData implements seqfile.SeqSource for /proc/[pid]/maps. // // +stateify savable diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 84baaac66..6af3c3781 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "root_inode_refs.go", package = "devpts", prefix = "rootInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "rootInode", }, @@ -33,6 +33,7 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d5c5aaa8c..9185877f6 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir } fstype.initOnce.Do(func() { - fs, root, err := fstype.newFilesystem(vfsObj, creds) + fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds) if err != nil { fstype.initErr = err return @@ -93,7 +93,7 @@ type filesystem struct { // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { +func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -108,7 +108,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds root := &rootInode{ replicas: make(map[uint32]*replicaInode), } - root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) root.EnableLeakCheck() @@ -120,7 +120,7 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) // Add the master as a child of the root. links := root.OrderedChildren.Populate(map[string]kernfs.Inode{ @@ -170,7 +170,7 @@ type rootInode struct { var _ kernfs.Inode = (*rootInode)(nil) // allocateTerminal creates a new Terminal and installs a pts node for it. -func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) { +func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) { i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { @@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) i.replicas[idx] = replica return t, nil @@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro } // IterDirents implements kernfs.Inode.IterDirents. -func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() + i.InodeAttrs.TouchAtime(ctx, mnt) ids := make([]int, 0, len(i.replicas)) for id := range i.replicas { ids = append(ids, int(id)) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e6b0e81cf..ae95fdd08 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -100,10 +100,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index fda30fb93..e91fa26a4 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - t, err := mi.root.allocateTerminal(rp.Credentials()) + t, err := mi.root.allocateTerminal(ctx, rp.Credentials()) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index 01bbee5ad..e49a04c1b 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -4,7 +4,10 @@ licenses(["notice"]) go_library( name = "devtmpfs", - srcs = ["devtmpfs.go"], + srcs = [ + "devtmpfs.go", + "save_restore.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go new file mode 100644 index 000000000..28832d850 --- /dev/null +++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devtmpfs + +// afterLoad is invoked by stateify. +func (fst *FilesystemType) afterLoad() { + if fst.fs != nil { + // Ensure that we don't create another filesystem. + fst.initOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 1c27ad700..5b29f2358 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -43,7 +43,7 @@ type EventFileDescription struct { // queue is used to notify interested parties when the event object // becomes readable or writable. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // mu protects the fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 045d7ab08..2158b1bbc 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "inode_refs.go", package = "fuse", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -49,6 +49,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index e39df21c6..e7ef5998e 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newRootInode(creds, fsopts.rootMode) + root := fs.newRootInode(ctx, creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } @@ -284,9 +284,9 @@ type inode struct { link string } -func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newRootInode(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { i := &inode{fs: fs, nodeID: 1} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -295,10 +295,10 @@ func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) return &d } -func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { +func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { i := &inode{fs: fs, nodeID: nodeID} creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} - i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -424,7 +424,7 @@ func (i *inode) Keep() bool { } // IterDirents implements kernfs.Inode.IterDirents. -func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } @@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { return nil, syserror.EIO } - child := i.fs.newInode(out.NodeID, out.Attr) + child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil } @@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return linux.FUSEAttr{}, err @@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 625d1547f..2d396e84c 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u // May need to update the signature. i := fd.inode() - // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount()) // Reached EOF. if sizeRead < size { @@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, Flags: fd.statusFlags(), } + inode := fd.inode() var written uint32 // This loop is intended for fragmented write where the bytes to write is @@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) if err != nil { return 0, err } @@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, break } } + inode.InodeAttrs.TouchCMtime(ctx) return written, nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index ad0afc41b..4c3e9acf8 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "save_restore.go", "socket.go", "special_file.go", "symlink.go", @@ -53,6 +54,7 @@ go_library( "//pkg/log", "//pkg/p9", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", @@ -70,6 +72,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 18c884b59..e993c8e36 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -16,16 +16,17 @@ package gofer import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { child := &dentry{ refs: 1, // held by d fs: d.fs, - ino: d.fs.nextSyntheticIno(), + ino: d.fs.nextIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -100,6 +101,9 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { hostFD: -1, nlink: uint32(2), } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(child, "gofer.dentry") + } switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. @@ -235,7 +239,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: uint64(inoFromPath(p9d.QID.Path)), + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 94d96261b..baecb88c4 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -35,7 +35,7 @@ import ( // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current syncable dentries and special files. + // Snapshot current syncable dentries and special file FDs. fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { @@ -53,22 +53,28 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error - // Sync regular files. + // Sync syncable dentries. for _, d := range ds { - err := d.syncCachedFile(ctx) + err := d.syncCachedFile(ctx, true /* forFilesystemSync */) d.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } } } // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - err := sffd.Sync(ctx) + err := sffd.sync(ctx, true /* forFilesystemSync */) sffd.vfsfd.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } } } @@ -229,7 +235,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir return nil, err } if child != nil { - if !file.isNil() && inoFromPath(qid.Path) == child.ino { + if !file.isNil() && qid.Path == child.qidPath { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) child.updateFromP9AttrsLocked(attrMask, &attr) @@ -1512,7 +1518,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath d.IncRef() return &endpoint{ dentry: d, - file: d.file.file, path: opts.Addr, }, nil } @@ -1591,7 +1596,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } - -func (fs *filesystem) nextSyntheticIno() inodeNumber { - return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index f1dad1b08..80668ebc1 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -26,6 +26,9 @@ // *** "memmap.Mappable locks taken by Translate" below this point // dentry.handleMu // dentry.dataMu +// filesystem.inoMu +// specialFileFD.mu +// specialFileFD.bufMu // // Locking dentry.dirMu in multiple dentries requires that either ancestor // dentries are locked before descendant dentries, or that filesystem.renameMu @@ -36,7 +39,6 @@ import ( "fmt" "strconv" "strings" - "sync" "sync/atomic" "syscall" @@ -44,6 +46,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -53,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -81,7 +86,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + client *p9.Client `state:"nosave"` // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -89,6 +94,9 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // root is the root dentry. root is immutable. + root *dentry + // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -103,39 +111,35 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. + // cachedDentriesLen is the number of dentries in cachedDentries. These fields + // are protected by renameMu. cachedDentries dentryList cachedDentriesLen uint64 - // syncableDentries contains all dentries in this filesystem for which - // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. - // These fields are protected by syncMu. + // syncableDentries contains all non-synthetic dentries. specialFileFDs + // contains all open specialFileFDs. These fields are protected by syncMu. syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} - // syntheticSeq stores a counter to used to generate unique inodeNumber for - // synthetic dentries. - syntheticSeq uint64 -} + // inoByQIDPath maps previously-observed QID.Paths to inode numbers + // assigned to those paths. inoByQIDPath is not preserved across + // checkpoint/restore because QIDs may be reused between different gofer + // processes, so QIDs may be repeated for different files across + // checkpoint/restore. inoByQIDPath is protected by inoMu. + inoMu sync.Mutex `state:"nosave"` + inoByQIDPath map[uint64]uint64 `state:"nosave"` -// inodeNumber represents inode number reported in Dirent.Ino. For regular -// dentries, it comes from QID.Path from the 9P server. Synthetic dentries -// have have their inodeNumber generated sequentially, with the MSB reserved to -// prevent conflicts with regular dentries. -// -// +stateify savable -type inodeNumber uint64 + // lastIno is the last inode number assigned to a file. lastIno is accessed + // using atomic memory operations. + lastIno uint64 -// Reserve MSB for synthetic mounts. -const syntheticInoMask = uint64(1) << 63 + // savedDentryRW records open read/write handles during save/restore. + savedDentryRW map[*dentry]savedDentryRW -func inoFromPath(path uint64) inodeNumber { - if path&syntheticInoMask != 0 { - log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) - } - return inodeNumber(path &^ syntheticInoMask) + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // +stateify savable @@ -149,8 +153,7 @@ type filesystemOptions struct { msize uint32 version string - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. + // maxCachedDentries is the maximum size of filesystem.cachedDentries. maxCachedDentries uint64 // If forcePageCache is true, host FDs may not be used for application @@ -247,6 +250,10 @@ const ( // // +stateify savable type InternalFilesystemOptions struct { + // If UniqueID is non-empty, it is an opaque string used to reassociate the + // filesystem with a new server FD during restoration from checkpoint. + UniqueID string + // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in // which servers can handle only a single client and report failure if that @@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mopts := vfs.GenericParseMountOptions(opts.Data) var fsopts filesystemOptions - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) + fd, err := getFDFromMountOptionsMap(ctx, mopts) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL + return nil, nil, err } - fsopts.fd = rfd + fsopts.fd = fd // Get the attach name. fsopts.aname = "/" @@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // If !ok, iopts being the zero value is correct. - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) + // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } + fs := &filesystem{ + mfp: mfp, + opts: fsopts, + iopts: iopts, + clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + syncableDentries: make(map[*dentry]struct{}), + specialFileFDs: make(map[*specialFileFD]struct{}), + inoByQIDPath: make(map[uint64]uint64), + } + fs.vfsfs.Init(vfsObj, &fstype, fs) - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() + // Connect to the server. + if err := fs.dial(ctx); err != nil { return nil, nil, err } - // Ownership of conn has been transferred to client. // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fsopts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } - // Construct the filesystem object. - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - attachFile.close(ctx) - client.Close() - return nil, nil, err - } - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - iopts: iopts, - client: client, - clock: ktime.RealtimeClockFromContext(ctx), - devMinor: devMinor, - syncableDentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, &fstype, fs) - // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { @@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. root.refs = 2 + fs.root = root return &fs.vfsfs, &root.vfsd, nil } +func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { + // Check that the transport is "fd". + trans, ok := mopts["trans"] + if !ok || trans != "fd" { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + return -1, syserror.EINVAL + } + delete(mopts, "trans") + + // Check that read and write FDs are provided and identical. + rfdstr, ok := mopts["rfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "rfdno") + rfd, err := strconv.Atoi(rfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + return -1, syserror.EINVAL + } + wfdstr, ok := mopts["wfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "wfdno") + wfd, err := strconv.Atoi(wfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + return -1, syserror.EINVAL + } + if rfd != wfd { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) + return -1, syserror.EINVAL + } + return rfd, nil +} + +// Preconditions: fs.client == nil. +func (fs *filesystem) dial(ctx context.Context) error { + // Establish a connection with the server. + conn, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return err + } + + // Perform version negotiation with the server. + ctx.UninterruptibleSleepStart(false) + client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + conn.Close() + return err + } + // Ownership of conn has been transferred to client. + + fs.client = client + return nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { - mf := fs.mfp.MemoryFile() + atomic.StoreInt32(&fs.released, 1) + mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) { // fs. fs.syncMu.Unlock() + // If leak checking is enabled, release all outstanding references in the + // filesystem. We deliberately avoid doing this outside of leak checking; we + // have released all external resources above rather than relying on dentry + // destructors. + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + fs.renameMu.Lock() + fs.root.releaseSyntheticRecursiveLocked(ctx) + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent it from + // being cached/evicted. + fs.root.DecRef(ctx) + } + if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() @@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements +// the reference count on every synthetic dentry. Synthetic dentries have one +// reference for existence that should be dropped during filesystem.Release. +// +// Precondition: d.fs.renameMu is locked. +func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { + if d.isSynthetic() { + d.decRefLocked() + d.checkCachingLocked(ctx) + } + if d.isDir() { + var children []*dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + if child != nil { + child.releaseSyntheticRecursiveLocked(ctx) + } + } + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -574,12 +635,15 @@ type dentry struct { // filesystem.renameMu. name string + // qidPath is the p9.QID.Path for this file. qidPath is immutable. + qidPath uint64 + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + file p9file `state:"nosave"` // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -623,12 +687,12 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex `state:"nosave"` - ino inodeNumber // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + metadataMu sync.Mutex `state:"nosave"` + ino uint64 // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 @@ -679,9 +743,9 @@ type dentry struct { // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - hostFD int32 + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + hostFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, + qidPath: qid.Path, file: file, - ino: inoFromPath(qid.Path), + ino: fs.inoFromQIDPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -795,6 +860,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d.nlink = uint32(attr.NLink) } d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "gofer.dentry") + } fs.syncMu.Lock() fs.syncableDentries[d] = struct{}{} @@ -802,6 +870,21 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma return d, nil } +func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + if ino, ok := fs.inoByQIDPath[qidPath]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByQIDPath[qidPath] = ino + return ino +} + +func (fs *filesystem) nextIno() uint64 { + return atomic.AddUint64(&fs.lastIno, 1) +} + func (d *dentry) isSynthetic() bool { return d.file.isNil() } @@ -853,7 +936,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } -// Preconditions: !d.isSynthetic() +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { // Use d.readFile or d.writeFile, which represent 9P fids that have been // opened, in preference to d.file, which represents a 9P fid that has not. @@ -916,10 +999,10 @@ func (d *dentry) statTo(stat *linux.Statx) { // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime)) + stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = d.fs.devMinor } @@ -967,10 +1050,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // Use client clocks for timestamps. now = d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { - stat.Atime = statxTimestampFromDentry(now) + stat.Atime = linux.NsecToStatxTimestamp(now) } if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { - stat.Mtime = statxTimestampFromDentry(now) + stat.Mtime = linux.NsecToStatxTimestamp(now) } } @@ -1029,11 +1112,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. if stat.Mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } if stat.Mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) @@ -1175,6 +1258,11 @@ func (d *dentry) decRefLocked() { } } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { if d.isDir() { @@ -1223,6 +1311,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + return + } if refs > 0 { if d.cached { d.fs.cachedDentries.Remove(d) @@ -1231,10 +1323,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } return } - if refs == -1 { - // Dentry has already been destroyed. - return - } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { @@ -1257,6 +1345,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.watches.Size() > 0 { return } + + if atomic.LoadInt32(&d.fs.released) != 0 { + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) + } + // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1269,33 +1367,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentriesLen++ d.cached = true if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) - d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() - } - victim.destroyLocked(ctx) - } + d.fs.evictCachedDentryLocked(ctx) // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. } } +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * fs.cachedDentriesLen != 0. +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. + } + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } +} + // destroyLocked destroys the dentry. // // Preconditions: @@ -1380,6 +1493,10 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("gofer.dentry.DecRef() called without holding a reference") } } + + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "gofer.dentry") + } } func (d *dentry) isDeleted() bool { @@ -1623,6 +1740,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { return nil } +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() + if h.isOpen() { + // Write back dirty pages to the remote file. + d.dataMu.Lock() + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } + } + if err := d.syncRemoteFileLocked(ctx); err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (d is a regular file and was opened for writing). + if d.isRegularFile() && h.isOpen() { + return err + } + ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1650,7 +1794,7 @@ type fileDescription struct { vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + lockLogging sync.Once `state:"nosave"` } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index bfe75dfe4..76f08e252 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -26,12 +26,13 @@ import ( func TestDestroyIdempotent(t *testing.T) { ctx := contexttest.Context(t) fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - syncableDentries: make(map[*dentry]struct{}), + mfp: pgalloc.MemoryFileProviderFromContext(ctx), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. maxCachedDentries: 0, }, + syncableDentries: make(map[*dentry]struct{}), + inoByQIDPath: make(map[uint64]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 7294de7d6..c7bf10007 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { if ok { return nil } - if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { - return err + if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil { + // Another application thread may have opened this pipe for + // writing, succeeded because we previously opened the pipe for + // reading, and subsequently interrupted us for checkpointing (e.g. + // this occurs in mknod tests under cooperative save/restore). In + // this case, our open has to succeed for the checkpoint to include + // a readable FD for the pipe, which is in turn necessary to + // restore the other thread's writable FD for the same pipe + // (otherwise it will get ENXIO). So we have to check + // nonblockingPipeHasWriter() once last time. + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + return sleepErr } } } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f8b19bae7..dc8a890cb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -18,7 +18,6 @@ import ( "fmt" "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx) -} - -func (d *dentry) syncCachedFile(ctx context.Context) error { - d.handleMu.RLock() - defer d.handleMu.RUnlock() - - if h := d.writeHandleLocked(); h.isOpen() { - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) - d.dataMu.Unlock() - if err != nil { - return err - } - } - return d.syncRemoteFileLocked(ctx) + return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -913,7 +897,7 @@ type dentryPlatformFile struct { hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go new file mode 100644 index 000000000..2ea224c43 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -0,0 +1,329 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type saveRestoreContextID int + +const ( + // CtxRestoreServerFDMap is a Context.Value key for a map[string]int + // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID) + // to host FDs. + CtxRestoreServerFDMap saveRestoreContextID = iota +) + +// +stateify savable +type savedDentryRW struct { + read bool + write bool +} + +// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave. +func (fs *filesystem) PrepareSave(ctx context.Context) error { + if len(fs.iopts.UniqueID) == 0 { + return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved") + } + + // Purge cached dentries, which may not be reopenable after restore due to + // permission changes. + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // Buffer pipe data so that it's available for reading after restore. (This + // is a legacy VFS1 feature.) + fs.syncMu.Lock() + for sffd := range fs.specialFileFDs { + if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() { + if err := sffd.savePipeData(ctx); err != nil { + fs.syncMu.Unlock() + return err + } + } + } + fs.syncMu.Unlock() + + // Flush local state to the remote filesystem. + if err := fs.Sync(ctx); err != nil { + return err + } + + fs.savedDentryRW = make(map[*dentry]savedDentryRW) + return fs.root.prepareSaveRecursive(ctx) +} + +// Preconditions: +// * fd represents a pipe. +// * fd is readable. +func (fd *specialFileFD) savePipeData(ctx context.Context) error { + fd.bufMu.Lock() + defer fd.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) + if n != 0 { + fd.buf = append(fd.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syserror.EAGAIN { + break + } + return err + } + } + if len(fd.buf) != 0 { + atomic.StoreUint32(&fd.haveBuf, 1) + } + return nil +} + +func (d *dentry) prepareSaveRecursive(ctx context.Context) error { + if d.isRegularFile() && !d.cachedMetadataAuthoritative() { + // Get updated metadata for d in case we need to perform metadata + // validation during restore. + if err := d.updateFromGetattr(ctx); err != nil { + return err + } + } + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } + } + d.dirMu.Lock() + defer d.dirMu.Unlock() + for _, child := range d.children { + if child != nil { + if err := child.prepareSaveRecursive(ctx); err != nil { + return err + } + } + } + return nil +} + +// beforeSave is invoked by stateify. +func (d *dentry) beforeSave() { + if d.vfsd.IsDead() { + panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d))) + } +} + +// afterLoad is invoked by stateify. +func (d *dentry) afterLoad() { + d.hostFD = -1 + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "gofer.dentry") + } +} + +// afterLoad is invoked by stateify. +func (d *dentryPlatformFile) afterLoad() { + if d.hostFileMapper.IsInited() { + // Ensure that we don't call d.hostFileMapper.Init() again. + d.hostFileMapperInitOnce.Do(func() {}) + } +} + +// afterLoad is invoked by stateify. +func (fd *specialFileFD) afterLoad() { + fd.handle.fd = -1 +} + +// CompleteRestore implements +// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore. +func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error { + fdmapv := ctx.Value(CtxRestoreServerFDMap) + if fdmapv == nil { + return fmt.Errorf("no server FD map available") + } + fdmap := fdmapv.(map[string]int) + fd, ok := fdmap[fs.iopts.UniqueID] + if !ok { + return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) + } + fs.opts.fd = fd + if err := fs.dial(ctx); err != nil { + return err + } + fs.inoByQIDPath = make(map[uint64]uint64) + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } + + // Restore remaining dentries. + if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil { + return err + } + + // Re-open handles for specialFileFDs. Unlike the initial open + // (dentry.openSpecialFile()), pipes are always opened without blocking; + // non-readable pipe FDs are opened last to ensure that they don't get + // ENXIO if another specialFileFD represents the read end of the same pipe. + // This is consistent with VFS1. + haveWriteOnlyPipes := false + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + haveWriteOnlyPipes = true + continue + } + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + if haveWriteOnlyPipes { + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + } + } + + // Discard state only required during restore. + fs.savedDentryRW = nil + + return nil +} + +func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error { + d.file = file + + // Gofers do not preserve QID across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking QIDs. + // + // - We need to associate the new QID.Path with the existing d.ino. + d.qidPath = qid.Path + d.fs.inoMu.Lock() + d.fs.inoByQIDPath[qid.Path] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if !attrMask.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != attr.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if !attrMask.MTime { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromP9AttrsLocked(attrMask, attr) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + +// Preconditions: d is not synthetic. +func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + for _, child := range d.children { + if child == nil { + continue + } + if _, ok := d.fs.syncableDentries[child]; !ok { + // child is synthetic. + continue + } + if err := child.restoreRecursive(ctx, opts); err != nil { + return err + } + } + return nil +} + +// Preconditions: d is not synthetic (but note that since this function +// restores d.file, d.file.isNil() is always true at this point, so this can +// only be detected by checking filesystem.syncableDentries). d.parent has been +// restored. +func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } + return d.restoreDescendantsRecursive(ctx, opts) +} + +func (fd *specialFileFD) completeRestore(ctx context.Context) error { + d := fd.dentry() + h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + if err != nil { + return err + } + fd.handle = h + + ftype := d.fileType() + fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0 + if fd.haveQueue { + if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 326b940a7..a21199eac 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -42,9 +42,6 @@ type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry - // file is the p9 file that contains a single unopened fid. - file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - // path is the sentry path where this endpoint is bound. path string } @@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect } func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.file.Connect(flags) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { return nil, syserr.ErrConnectionRefused } @@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) return nil, serr } return c, nil diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 71581736c..625400c0b 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -15,7 +15,6 @@ package gofer import ( - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -40,7 +40,7 @@ type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + handle handle `state:"nosave"` // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -54,12 +54,20 @@ type specialFileFD struct { // haveQueue is true if this file description represents a file for which // queue may send I/O readiness events. haveQueue is immutable. - haveQueue bool + haveQueue bool `state:"nosave"` queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex `state:"nosave"` off int64 + + // If haveBuf is non-zero, this FD represents a pipe, and buf contains data + // read from the pipe from previous calls to specialFileFD.savePipeData(). + // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic + // memory operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { @@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } return nil, err } + d.fs.syncMu.Lock() + d.fs.specialFileFDs[fd] = struct{}{} + d.fs.syncMu.Unlock() return fd, nil } @@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, syserror.EOPNOTSUPP } - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } + + bufN := int64(0) + if atomic.LoadUint32(&fd.haveBuf) != 0 { + var err error + fd.bufMu.Lock() + if len(fd.buf) != 0 { + var n int + n, err = dst.CopyOut(ctx, fd.buf) + dst = dst.DropFirst(n) + fd.buf = fd.buf[n:] + if len(fd.buf) == 0 { + atomic.StoreUint32(&fd.haveBuf, 0) + fd.buf = nil + } + bufN = int64(n) + if offset >= 0 { + offset += bufN + } + } + fd.bufMu.Unlock() + if err != nil { + return bufN, err + } + } + + // Going through dst.CopyOutFrom() would hold MM locks around file + // operations of unknown duration. For regularFileFD, doing so is necessary + // to support mmap due to lock ordering; MM locks precede dentry.dataMu. + // That doesn't hold here since specialFileFD doesn't client-cache data. + // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } if n == 0 { - return 0, err + return bufN, err } if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr + return bufN + int64(cp), cperr } - return int64(n), err + return bufN + int64(n), err } // Read implements vfs.FileDescriptionImpl.Read. @@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() - // If the regular file fd was opened with O_APPEND, make sure the file size - // is updated. There is a possible race here if size is modified externally - // after metadata cache is updated. - if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, offset, err + if fd.isRegularFile { + // If the regular file fd was opened with O_APPEND, make sure the file + // size is updated. There is a possible race here if size is modified + // externally after metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - } - if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if fd.handle.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(fd.handle.fd)) - ctx.UninterruptibleSleepFinish(false) - return err + return fd.sync(ctx, false /* forFilesystemSync */) +} + +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { + err := func() error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) + }() + if err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (fd represents a regular file that was opened for writing). + if fd.isRegularFile && fd.vfsfd.IsWritable() { + return err + } + ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err) } - return fd.handle.file.fsync(ctx) + return nil } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7e825caae..9cbe805b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,7 +17,6 @@ package gofer import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 56bcf9bdb..dc0f86061 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "inode_refs.go", package = "host", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "connected_endpoint_refs.go", package = "host", prefix = "ConnectedEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ConnectedEndpoint", }, @@ -34,6 +34,7 @@ go_library( "inode_refs.go", "ioctl_unsafe.go", "mmap.go", + "save_restore.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -51,6 +52,7 @@ go_library( "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index 0135e4428..13ef48cb5 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription { } // Create the file backed by hostFD. - file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */) + file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{}) if err != nil { ctx.Warningf("Error creating file from host FD: %v", err) break diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 698e913fe..eeed0f97d 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -19,6 +19,7 @@ package host import ( "fmt" "math" + "sync/atomic" "syscall" "golang.org/x/sys/unix" @@ -40,34 +41,106 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + kernfs.InodeNoStatFS + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. + + locks vfs.FileLocks + + // When the reference count reaches zero, the host fd is closed. + inodeRefs + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // ino is an inode number unique within this filesystem. + // + // This field is initialized at creation time and is immutable. + ino uint64 + + // ftype is the file's type (a linux.S_IFMT mask). + // + // This field is initialized at creation time and is immutable. + ftype uint16 + + // mayBlock is true if hostFD is non-blocking, and operations on it may + // return EAGAIN or EWOULDBLOCK instead of blocking. + // + // This field is initialized at creation time and is immutable. + mayBlock bool + + // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file + // offsets are meaningful iff seekable is true. + // + // This field is initialized at creation time and is immutable. + seekable bool + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool + + // savable is true if hostFD may be saved/restored by its numeric value. + // + // This field is initialized at creation time and is immutable. + savable bool + + // Event queue for blocking operations. + queue waiter.Queue + + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // If this file is mmappable, mappings tracks mappings of hostFD into + // memmap.MappingSpaces. + mappings memmap.MappingSet + + // pf implements platform.File for mappings of hostFD. + pf inodePlatformFile + + // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data + // read from the pipe from previous calls to inode.beforeSave(). haveBuf + // and buf are protected by bufMu. haveBuf is accessed using atomic memory + // operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte +} + +func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) seekable := err != syserror.ESPIPE + // We expect regular files to be seekable, as this is required for them to + // be memory-mappable. + if !seekable && fileType == syscall.S_IFREG { + ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD) + return nil, syserror.ESPIPE + } i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: isTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, + hostFD: hostFD, + ino: fs.NextIno(), + ftype: uint16(fileType), + mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR, + seekable: seekable, + isTTY: isTTY, + savable: savable, } i.pf.inode = i i.EnableLeakCheck() - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { + // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and + // handle blocking behavior in the sentry. + if i.mayBlock { if err := syscall.SetNonblock(i.hostFD, true); err != nil { return nil, err } @@ -80,6 +153,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) ( // NewFDOptions contains options to NewFD. type NewFDOptions struct { + // If Savable is true, the host file descriptor may be saved/restored by + // numeric value; the sandbox API requires a corresponding host FD with the + // same numeric value to be provieded at time of restore. + Savable bool + // If IsTTY is true, the file descriptor is a TTY. IsTTY bool @@ -114,7 +192,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) } d := &kernfs.Dentry{} - i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY) if err != nil { return nil, err } @@ -132,7 +210,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // ImportFD sets up and returns a vfs.FileDescription from a donated fd. func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) { return NewFD(ctx, mnt, hostFD, &NewFDOptions{ - IsTTY: isTTY, + Savable: true, + IsTTY: isTTY, }) } @@ -191,68 +270,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } -// inode implements kernfs.Inode. -// -// +stateify savable -type inode struct { - kernfs.InodeNoStatFS - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. - - locks vfs.FileLocks - - // When the reference count reaches zero, the host fd is closed. - inodeRefs - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // seekable is false if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - seekable bool - - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. - // - // This field is initialized at creation time and is immutable. - wouldBlock bool - - // Event queue for blocking operations. - queue waiter.Queue - - // canMap specifies whether we allow the file to be memory mapped. - // - // This field is initialized at creation time and is immutable. - canMap bool - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // If canMap is true, mappings tracks mappings of hostFD into - // memmap.MappingSpaces. - mappings memmap.MappingSet - - // pf implements platform.File for mappings of hostFD. - pf inodePlatformFile -} - // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t @@ -448,7 +465,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre // DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { i.inodeRefs.DecRef(func() { - if i.wouldBlock { + if i.mayBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } if err := unix.Close(i.hostFD); err != nil { @@ -567,6 +584,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin // PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { return 0, syserror.ESPIPE @@ -577,19 +601,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { + bufN, err := i.readFromBuf(ctx, &dst) + if err != nil { + return bufN, err + } n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) + total := bufN + n if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read // rather than retrying until complete. - if n != 0 { + if total != 0 { err = nil } else { err = syserror.ErrWouldBlock } } - return n, err + return total, err } f.offsetMu.Lock() @@ -599,13 +635,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } -func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // Check that flags are supported. - // - // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. - if flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP +func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) { + if atomic.LoadUint32(&i.haveBuf) == 0 { + return 0, nil } + i.bufMu.Lock() + defer i.bufMu.Unlock() + if len(i.buf) == 0 { + return 0, nil + } + n, err := dst.CopyOut(ctx, i.buf) + *dst = dst.DropFirst(n) + i.buf = i.buf[n:] + if len(i.buf) == 0 { + atomic.StoreUint32(&i.haveBuf, 0) + i.buf = nil + } + return int64(n), err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := dst.CopyOutFrom(ctx, reader) hostfd.PutReadWriterAt(reader) @@ -735,14 +784,16 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i } // Sync implements vfs.FileDescriptionImpl.Sync. -func (f *fileDescription) Sync(context.Context) error { +func (f *fileDescription) Sync(ctx context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.inode.canMap { + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + if f.inode.ftype != syscall.S_IFREG { return syserror.ENODEV } i := f.inode @@ -753,13 +804,17 @@ func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts // EventRegister implements waiter.Waitable.EventRegister. func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { f.inode.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // EventUnregister implements waiter.Waitable.EventUnregister. func (f *fileDescription) EventUnregister(e *waiter.Entry) { f.inode.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // Readiness uses the poll() syscall to check the status of the underlying FD. diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index b51a17bed..3d7eb2f96 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -43,7 +43,7 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + fileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go new file mode 100644 index 000000000..7e32a8863 --- /dev/null +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -0,0 +1,78 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/usermem" +) + +// beforeSave is invoked by stateify. +func (i *inode) beforeSave() { + if !i.savable { + panic("host.inode is not savable") + } + if i.ftype == syscall.S_IFIFO { + // If this pipe FD is readable, drain it so that bytes in the pipe can + // be read after restore. (This is a legacy VFS1 feature.) We don't + // know if the pipe FD is readable, so just try reading and tolerate + // EBADF from the read. + i.bufMu.Lock() + defer i.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) + if n != 0 { + i.buf = append(i.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF { + break + } + panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err)) + } + } + if len(i.buf) != 0 { + atomic.StoreUint32(&i.haveBuf, 1) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inode) afterLoad() { + if i.mayBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err)) + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err)) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inodePlatformFile) afterLoad() { + if i.fileMapper.IsInited() { + // Ensure that we don't call i.fileMapper.Init() again. + i.fileMapperInitOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 412bdb2eb..b2f43a119 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} } -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. -func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 858cc24ce..aaad67ab8 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "kernfs", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Dentry", + "Linker": "*Dentry", + }, +) + +go_template_instance( name = "fstree", out = "fstree.go", package = "kernfs", @@ -27,22 +39,11 @@ go_template_instance( ) go_template_instance( - name = "dentry_refs", - out = "dentry_refs.go", - package = "kernfs", - prefix = "Dentry", - template = "//pkg/refs_vfs2:refs_template", - types = { - "T": "Dentry", - }, -) - -go_template_instance( name = "static_directory_refs", out = "static_directory_refs.go", package = "kernfs", prefix = "StaticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "StaticDirectory", }, @@ -53,7 +54,7 @@ go_template_instance( out = "dir_refs.go", package = "kernfs_test", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -64,7 +65,7 @@ go_template_instance( out = "readonly_dir_refs.go", package = "kernfs_test", prefix = "readonlyDir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "readonlyDir", }, @@ -75,7 +76,7 @@ go_template_instance( out = "synthetic_directory_refs.go", package = "kernfs", prefix = "syntheticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "syntheticDirectory", }, @@ -84,7 +85,7 @@ go_template_instance( go_library( name = "kernfs", srcs = [ - "dentry_refs.go", + "dentry_list.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", @@ -104,8 +105,11 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", @@ -129,6 +133,7 @@ go_test( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index b929118b1..485504995 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -47,11 +47,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index abf1905d6..f8dae22f8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() } +func (fd *GenericDirectoryFD) dentry() *Dentry { + return fd.vfsfd.Dentry().Impl().(*Dentry) +} + func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return fd.dentry().inode } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds @@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { - vfsd := fd.vfsfd.VirtualDentry().Dentry() - parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode + parentInode := genericParentOrSelf(fd.dentry()).inode stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err @@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent var err error relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset) return err } @@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fd.filesystem(), creds, opts) + return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts) } // Allocate implements vfs.FileDescriptionImpl.Allocate. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 6426a55f6..399895f3e 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -373,7 +373,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } - childI = newSyntheticDirectory(rp.Credentials(), opts.Mode) + childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) } var child Dentry child.Init(fs, childI) @@ -517,9 +517,6 @@ afterTrailingSymlink: } var child Dentry child.Init(fs, childI) - // FIXME(gvisor.dev/issue/1193): Race between checking existence with - // fs.stepExistingLocked and parent.insertChild. If possible, we should hold - // dirMu from one to the other. parent.insertChild(pc, &child) // Open may block so we need to unlock fs.mu. IncRef child to prevent // its destruction while fs.mu is unlocked. diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 122b10591..d9d76758a 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -21,9 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) } // IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } @@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // +stateify savable type InodeAttrs struct { - devMajor uint32 - devMinor uint32 - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 + blockSize uint32 + + // Timestamps, all nsecs from the Unix epoch. + atime int64 + mtime int64 + ctime int64 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) + atomic.StoreUint32(&a.blockSize, usermem.PageSize) + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.atime, now) + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) } // DevMajor returns the device major number. @@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// TouchAtime updates a.atime to the current time. +func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds()) + mnt.EndWrite() +} + +// TouchCMtime updates a.{c/m}time to the current time. The caller should +// synchronize calls to this so that ctime and mtime are updated to the same +// value. +func (a *InodeAttrs) TouchCMtime(ctx context.Context) { + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) +} + // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME stat.DevMajor = a.devMajor stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) @@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li stat.UID = atomic.LoadUint32(&a.uid) stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - + stat.Blksize = atomic.LoadUint32(&a.blockSize) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime)) return stat, nil } // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - return a.SetInodeStat(ctx, fs, creds, opts) -} - -// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. -// This function can be used by other kernfs-based filesystem implementation to -// sets the unexported attributes into InodeAttrs. -func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } @@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds // inode numbers are immutable after node creation. Setting the size is often // allowed by kernfs files but does not do anything. If some other behavior is // needed, the embedder should consider extending SetStat. - // - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { @@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds atomic.StoreUint32(&a.gid, stat.GID) } + now := ktime.NowFromContext(ctx).Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.atime, stat.Atime.ToNsec()) + } + if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec()) + } + return nil } @@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error } // IterDirents implements Inode.IterDirents. -func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { // All entries from OrderedChildren have already been handled in // GenericDirectoryFD.IterDirents. return offset, nil @@ -619,9 +659,9 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { +func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) inode.EnableLeakCheck() inode.OrderedChildren.Init(OrderedChildrenOptions{}) @@ -632,12 +672,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { +func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } s.fdOpts = fdOpts - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 606081e68..5c5e09ac5 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -107,6 +107,17 @@ type Filesystem struct { // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. nextInoMinusOne uint64 + + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by mu. + cachedDentries dentryList + cachedDentriesLen uint64 + + // MaxCachedDentries is the maximum size of cachedDentries. If not set, + // defaults to 0 and kernfs does not cache any dentries. This is immutable. + MaxCachedDentries uint64 } // deferDecRef defers dropping a dentry ref until the next call to @@ -165,7 +176,12 @@ const ( // +stateify savable type Dentry struct { vfsd vfs.Dentry - DentryRefs + + // refs is the reference count. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has already + // been destroyed. refs are allowed to go to 0 and increase again. refs is + // accessed using atomic memory operations. + refs int64 // fs is the owning filesystem. fs is immutable. fs *Filesystem @@ -177,6 +193,12 @@ type Dentry struct { parent *Dentry name string + // If cached is true, dentryEntry links dentry into + // Filesystem.cachedDentries. cached and dentryEntry are protected by + // Filesystem.mu. + cached bool + dentryEntry + // dirMu protects children and the names of child Dentries. // // Note that holding fs.mu for writing is not sufficient; @@ -188,6 +210,150 @@ type Dentry struct { inode Inode } +// IncRef implements vfs.DentryImpl.IncRef. +func (d *Dentry) IncRef() { + // d.refs may be 0 if d.fs.mu is locked, which serializes against + // d.cacheLocked(). + atomic.AddInt64(&d.refs, 1) +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *Dentry) TryIncRef() bool { + for { + refs := atomic.LoadInt64(&d.refs) + if refs <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef(ctx context.Context) { + if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + d.fs.mu.Lock() + d.cacheLocked(ctx) + d.fs.mu.Unlock() + } else if refs < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +// cacheLocked should be called after d's reference count becomes 0. The ref +// count check may happen before acquiring d.fs.mu so there might be a race +// condition where the ref count is increased again by the time the caller +// acquires d.fs.mu. This race is handled. +// Only reachable dentries are added to the cache. However, a dentry might +// become unreachable *while* it is in the cache due to invalidation. +// +// Preconditions: d.fs.mu must be locked for writing. +func (d *Dentry) cacheLocked(ctx context.Context) { + // Dentries with a non-zero reference count must be retained. (The only way + // to obtain a reference on a dentry with zero references is via path + // resolution, which requires d.fs.mu, so if d.refs is zero then it will + // remain zero while we hold d.fs.mu for writing.) + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d)) + } + if refs > 0 { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + return + } + // If the dentry is deleted and invalidated or has no parent, then it is no + // longer reachable by path resolution and should be dropped immediately + // because it has zero references. + // Note that a dentry may not always have a parent; for example magic links + // as described in Inode.Getlink. + if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil { + if !isDead { + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) + } + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + d.destroyLocked(ctx) + return + } + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. + d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries { + return + } + // Evict the least recently used dentry because cache size is greater than + // max cache size (configured on mount). + victim := d.fs.cachedDentries.Back() + d.fs.cachedDentries.Remove(victim) + d.fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // after it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if !victim.vfsd.IsDead() { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry()) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.MaxCachedDentries, so we don't loop. +} + +// destroyLocked destroys the dentry. +// +// Preconditions: +// * d.fs.mu must be locked for writing. +// * d.refs == 0. +// * d should have been removed from d.parent.children, i.e. d is not reachable +// by path traversal. +// * d.vfsd.IsDead() is true. +func (d *Dentry) destroyLocked(ctx context.Context) { + switch atomic.LoadInt64(&d.refs) { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + + // Drop the reference held by d on its parent without recursively locking + // d.fs.mu. + if d.parent != nil { + if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { + d.parent.cacheLocked(ctx) + } else if refs < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } + } +} + // Init initializes this dentry. // // Precondition: Caller must hold a reference on inode. @@ -197,6 +363,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { d.vfsd.Init(d) d.fs = fs d.inode = inode + atomic.StoreInt64(&d.refs, 1) ftype := inode.Mode().FileType() if ftype == linux.ModeDirectory { d.flags |= dflagsIsDir @@ -204,7 +371,6 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } - d.EnableLeakCheck() } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -222,32 +388,6 @@ func (d *Dentry) isSymlink() bool { return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 } -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(ctx context.Context) { - decRefParent := false - d.fs.mu.Lock() - d.DentryRefs.DecRef(func() { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - // We will DecRef d.parent once all locks are dropped. - decRefParent = true - d.parent.dirMu.Lock() - // Remove d from parent.children. It might already have been - // removed due to invalidation. - if _, ok := d.parent.children[d.name]; ok { - delete(d.parent.children, d.name) - d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) - } - d.parent.dirMu.Unlock() - } - }) - d.fs.mu.Unlock() - if decRefParent { - d.parent.DecRef(ctx) // IncRef from Dentry.insertChild. - } -} - // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // Although Linux technically supports inotify on pseudo filesystems (inotify @@ -267,7 +407,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // this dentry. This does not update the directory inode, so calling this on its // own isn't sufficient to insert a child into a directory. // -// Precondition: d must represent a directory inode. +// Preconditions: +// * d must represent a directory inode. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChild(name string, child *Dentry) { d.dirMu.Lock() d.insertChildLocked(name, child) @@ -280,6 +422,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) { // Preconditions: // * d must represent a directory inode. // * d.dirMu must be locked. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d)) @@ -436,7 +579,7 @@ type inodeDirectory interface { // the inode is a directory. // // The child returned by Lookup will be hashed into the VFS dentry tree, - // atleast for the duration of the current FS operation. + // at least for the duration of the current FS operation. // // Lookup must return the child with an extra reference whose ownership is // transferred to the dentry that is created to point to that inode. If @@ -454,7 +597,7 @@ type inodeDirectory interface { // inside the entries returned by this IterDirents invocation. In other words, // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as // the return value, while 'relOffset' is the place to start iteration. - IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) + IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 82fa19c03..2418eec44 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file." // RootDentryFn is a generator function for creating the root dentry of a test // filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode +type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode // newTestSystem sets up a minimal environment for running a test, including an // instance of a test filesystem. Tests can control the contents of the @@ -72,10 +72,10 @@ type file struct { content string } -func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode { +func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) return f } @@ -105,9 +105,9 @@ type readonlyDir struct { locks vfs.FileLocks } -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &readonlyDir{} - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) dir.EnableLeakCheck() dir.IncLinks(dir.OrderedChildren.Populate(contents)) @@ -142,10 +142,10 @@ type dir struct { fs *filesystem } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) dir.EnableLeakCheck() @@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) { func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) + dir := d.fs.newDir(ctx, creds, opts.Mode, nil) if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) d.IncLinks(1) return dir, nil } func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") + f := d.fs.newFile(ctx, creds, "") if err := d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) return f, nil } @@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {} func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { fs := &filesystem{} fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(creds, fs) + root := fst.rootFn(ctx, creds, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, root) return fs.VFSFilesystem(), d.VFSDentry(), nil @@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst // -------------------- Remainder of the file are test cases -------------------- func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -228,9 +230,9 @@ func TestBasic(t *testing.T) { } func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) { } func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) { } func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, nil) + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, nil) }) defer sys.Destroy() @@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) { } func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(creds, 0755, nil), + "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir3": fs.newDir(ctx, creds, 0755, nil), }), - "file1": fs.newFile(creds, staticFileContent), + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 934cc6c9e..a0736c0d6 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -38,16 +38,16 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { +func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { inode := &StaticSymlink{} - inode.Init(creds, devMajor, devMinor, ino, target) + inode.Init(ctx, creds, devMajor, devMinor, ino, target) return inode } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { +func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode.Readlink. diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index d0ed17b18..463d77d79 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -41,17 +41,17 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) -func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode { +func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { inode := &syntheticDirectory{} - inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) return inode } -func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) @@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs if !opts.ForSyntheticMountpoint { return nil, syserror.EPERM } - subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) if err := dir.OrderedChildren.Insert(name, subdirI); err != nil { subdirI.DecRef(ctx) return nil, err } + dir.TouchCMtime(ctx) return subdirI, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 1e11b0428..fd6c55921 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -23,6 +23,7 @@ go_library( "fstree.go", "overlay.go", "regular_file.go", + "save_restore.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -30,6 +31,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 78a01bbb7..10161a08d 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -302,8 +302,14 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str child.devMinor = fs.dirDevMinor child.ino = fs.newDirIno() } else if !child.upperVD.Ok() { + childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor) + if err != nil { + ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) + child.destroyLocked(ctx) + return nil, err + } child.devMajor = linux.UNNAMED_MAJOR - child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()] + child.devMinor = childDevMinor } parent.IncRef() diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 4c5de8d32..c812f0a70 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,7 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// filesystem.devMu // *** "memmap.Mappable locks" below this point // dentry.mapsMu // *** "memmap.Mappable locks taken by Translate" below this point @@ -33,12 +34,14 @@ package overlay import ( + "fmt" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -99,10 +102,15 @@ type filesystem struct { // is immutable. dirDevMinor uint32 - // lowerDevMinors maps lower layer filesystems to device minor numbers - // assigned to non-directory files originating from that filesystem. - // lowerDevMinors is immutable. - lowerDevMinors map[*vfs.Filesystem]uint32 + // lowerDevMinors maps device numbers from lower layer filesystems to + // device minor numbers assigned to non-directory files originating from + // that filesystem. (This remapping is necessary for lower layers because a + // file on a lower layer, and that same file on an overlay, are + // distinguishable because they will diverge after copy-up; this isn't true + // for non-directory files already on the upper layer.) lowerDevMinors is + // protected by devMu. + devMu sync.Mutex `state:"nosave"` + lowerDevMinors map[layerDevNumber]uint32 // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different @@ -114,78 +122,69 @@ type filesystem struct { lastDirIno uint64 } +// +stateify savable +type layerDevNumber struct { + major uint32 + minor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { mopts := vfs.GenericParseMountOptions(opts.Data) fsoptsRaw := opts.InternalData - fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) - if fsoptsRaw != nil && !haveFSOpts { + fsopts, ok := fsoptsRaw.(FilesystemOptions) + if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } - if haveFSOpts { - if len(fsopts.LowerRoots) == 0 { - ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + + if upperPathname, ok := mopts["upperdir"]; ok { + if fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") return nil, nil, syserror.EINVAL } - if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + delete(mopts, "upperdir") + // Linux overlayfs also requires a workdir when upperdir is + // specified; we don't, so silently ignore this option. + delete(mopts, "workdir") + upperPath := fspath.Parse(upperPathname) + if !upperPath.Absolute { + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } - // We don't enforce a maximum number of lower layers when not - // configured by applications; the sandbox owner can have an overlay - // filesystem with any number of lower layers. - } else { - vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef(ctx) - upperPathname, ok := mopts["upperdir"] - if ok { - delete(mopts, "upperdir") - // Linux overlayfs also requires a workdir when upperdir is - // specified; we don't, so silently ignore this option. - delete(mopts, "workdir") - upperPath := fspath.Parse(upperPathname) - if !upperPath.Absolute { - ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL - } - upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ - Root: vfsroot, - Start: vfsroot, - Path: upperPath, - FollowFinalSymlink: true, - }, &vfs.GetDentryOptions{ - CheckSearchable: true, - }) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer upperRoot.DecRef(ctx) - privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer privateUpperRoot.DecRef(ctx) - fsopts.UpperRoot = privateUpperRoot + upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: upperPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + return nil, nil, err + } + privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) + upperRoot.DecRef(ctx) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + return nil, nil, err } - lowerPathnamesStr, ok := mopts["lowerdir"] - if !ok { - ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + defer privateUpperRoot.DecRef(ctx) + fsopts.UpperRoot = privateUpperRoot + } + + if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { + if len(fsopts.LowerRoots) != 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") - const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK - if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") - return nil, nil, syserror.EINVAL - } - if len(lowerPathnames) > maxLowerLayers { - ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) - return nil, nil, syserror.EINVAL - } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { @@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) + lowerRoot.DecRef(ctx) if err != nil { ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err @@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } + if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } - // Allocate device numbers. + if len(fsopts.LowerRoots) == 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") + return nil, nil, syserror.EINVAL + } + if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") + return nil, nil, syserror.EINVAL + } + const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK + if len(fsopts.LowerRoots) > maxLowerLayers { + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) + return nil, nil, syserror.EINVAL + } + + // Allocate dirDevMinor. lowerDevMinors are allocated dynamically. dirDevMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } - lowerDevMinors := make(map[*vfs.Filesystem]uint32) - for _, lowerRoot := range fsopts.LowerRoots { - lowerFS := lowerRoot.Mount().Filesystem() - if _, ok := lowerDevMinors[lowerFS]; !ok { - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - vfsObj.PutAnonBlockDevMinor(dirDevMinor) - for _, lowerDevMinor := range lowerDevMinors { - vfsObj.PutAnonBlockDevMinor(lowerDevMinor) - } - return nil, nil, err - } - lowerDevMinors[lowerFS] = devMinor - } - } // Take extra references held by the filesystem. if fsopts.UpperRoot.Ok() { @@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt opts: fsopts, creds: creds.Fork(), dirDevMinor: dirDevMinor, - lowerDevMinors: lowerDevMinors, + lowerDevMinors: make(map[layerDevNumber]uint32), } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root.ino = fs.newDirIno() } else if !root.upperVD.Ok() { root.devMajor = linux.UNNAMED_MAJOR - root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()] + rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err) + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + root.devMinor = rootDevMinor root.ino = rootStat.Ino } else { root.devMajor = rootStat.DevMajor @@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 { return atomic.AddUint64(&fs.lastDirIno, 1) } +func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) { + fs.devMu.Lock() + defer fs.devMu.Unlock() + orig := layerDevNumber{layerMajor, layerMinor} + if minor, ok := fs.lowerDevMinors[orig]; ok { + return minor, nil + } + minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor() + if err != nil { + return 0, err + } + fs.lowerDevMinors[orig] = minor + return minor, nil +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -458,9 +479,9 @@ type dentry struct { // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is // accessed using atomic memory operations. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` lowerMappings memmap.MappingSet - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` wrappedMappable memmap.Mappable isMappable uint32 @@ -484,6 +505,9 @@ func (fs *filesystem) newDentry() *dentry { } d.lowerVDs = d.inlineLowerVDs[:0] d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "overlay.dentry") + } return d } @@ -583,6 +607,14 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("overlay.dentry.DecRef() called without holding a reference") } } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "overlay.dentry") + } +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go new file mode 100644 index 000000000..054e17b17 --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package overlay + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "overlay.dentry") + } +} diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2e086e34c..5196a2a80 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "fd_dir_inode_refs.go", package = "proc", prefix = "fdDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdDirInode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "fd_info_dir_inode_refs.go", package = "proc", prefix = "fdInfoDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdInfoDirInode", }, @@ -30,7 +30,7 @@ go_template_instance( out = "subtasks_inode_refs.go", package = "proc", prefix = "subtasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "subtasksInode", }, @@ -41,7 +41,7 @@ go_template_instance( out = "task_inode_refs.go", package = "proc", prefix = "taskInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "taskInode", }, @@ -52,7 +52,7 @@ go_template_instance( out = "tasks_inode_refs.go", package = "proc", prefix = "tasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tasksInode", }, @@ -82,6 +82,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index fd70a07de..99abcab66 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -17,6 +17,7 @@ package proc import ( "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -24,10 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "proc" +const ( + // Name is the default filesystem name. + Name = "proc" + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType is the factory class for procfs. // @@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if err != nil { return nil, nil, err } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + procfs := &filesystem{ devMinor: devMinor, } + procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -74,7 +92,7 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - inode := procfs.newTasksInode(k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, cgroups) var dentry kernfs.Dentry dentry.Init(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil @@ -94,11 +112,11 @@ type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { - inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) +func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) return inode } @@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } -func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { - return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ +func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { + return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) } diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index bad2fab4f..cb3c5e0fd 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) subInode.EnableLeakCheck() @@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { return offset, syserror.ENOENT diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index b63a4eca0..19011b010 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(task, fs.NextIno(), 0400), "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), "net": fs.newTaskNetDir(task), @@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil) func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } @@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 2c80ac5c2..d268b44be 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Lookup implements kernfs.inodeDirectory.Lookup. @@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern task: task, fd: fd, } - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements Inode.IterDirents. -func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 79f8b7e9f..ba71d0fde 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -249,7 +250,7 @@ type commInode struct { func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in return int64(srclen), nil } +var _ kernfs.Inode = (*memInode)(nil) + +// memInode implements kernfs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + task *kernel.Task + locks vfs.FileLocks +} + +func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &memInode{task: task} + inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) +} + +// Open implements kernfs.Inode.Open. +func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, f.task, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(f.task); err != nil { + return nil, err + } + fd := &memFD{} + if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +var _ vfs.FileDescriptionImpl = (*memFD)(nil) + +// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem. +// +// +stateify savable +type memFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *memInode + + // mu guards the fields below. + mu sync.Mutex `state:"nosave"` + offset int64 +} + +// Init initializes memFD. +func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error { + fd.LockFD.Init(&inode.locks) + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return err + } + fd.inode = inode + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + case linux.SEEK_CUR: + offset += fd.offset + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.offset = offset + return offset, nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + m, err := getMMIncRef(fd.inode.task) + if err != nil { + return 0, nil + } + defer m.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.offset, opts) + fd.offset += n + fd.mu.Unlock() + return n, err +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *memFD) Release(context.Context) {} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil) func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil) func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode @@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // Create a synthetic inode to represent the namespace. fs := mnt.Filesystem().Impl().(*filesystem) + nsInode := &namespaceInode{} + nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444) dentry := &kernfs.Dentry{} - dentry.Init(&fs.Filesystem, &namespaceInode{}) + dentry.Init(&fs.Filesystem, nsInode) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1. mnt.IncRef() @@ -897,11 +1056,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 3425e8698..5a9ee111f 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": fs.newInode(root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(root, 0444, &netStatData{}), - "packet": fs.newInode(root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(task, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(task, root, 0444, &netStatData{}), + "packet": fs.newInode(task, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(root, 0444, newStaticFile(ptype)), - "route": fs.newInode(root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(task, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6)) } } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3259c3732..b81ea14bf 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -62,19 +62,19 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ - "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": fs.newInode(root, 0444, &filesystemsData{}), - "loadavg": fs.newInode(root, 0444, &loadavgData{}), - "sys": fs.newSysDir(root, k), - "meminfo": fs.newInode(root, 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), - "stat": fs.newInode(root, 0444, &statData{}), - "uptime": fs.newInode(root, 0444, &uptimeData{}), - "version": fs.newInode(root, 0444, &versionData{}), + "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), + "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), + "sys": fs.newSysDir(ctx, root, k), + "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newInode(ctx, root, 0444, &statData{}), + "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), + "version": fs.newInode(ctx, root, 0444, &versionData{}), } inode := &tasksInode{ @@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace fs: fs, cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) @@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err // If it failed to parse, check if it's one of the special handled files. switch name { case selfName: - return i.newSelfSymlink(root), nil + return i.newSelfSymlink(ctx, root), nil case threadSelfName: - return i.newThreadSelfSymlink(root), nil + return i.newThreadSelfSymlink(ctx, root), nil } return nil, syserror.ENOENT } @@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { +func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 07c27cdd9..01b7a6678 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -43,9 +43,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &selfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } @@ -84,9 +84,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &threadSelfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 95420368d..7c7afdcfa 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -40,93 +40,93 @@ const ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. -func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { - return fs.newStaticDir(root, map[string]kernfs.Inode{ - "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{ - "hostname": fs.newInode(root, 0444, &hostnameData{}), - "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { + return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), }), - "vm": fs.newStaticDir(root, map[string]kernfs.Inode{ - "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")), + "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), - "net": fs.newSysNetDir(root, k), + "net": fs.newSysNetDir(ctx, root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { +func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { var contents map[string]kernfs.Inode // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ - "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}), + "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")), - "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")), - "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")), - "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")), + "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), + "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")), - "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), - "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")), - "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")), - "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")), - "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")), - "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")), - "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")), + "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")), + "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")), }), - "core": fs.newStaticDir(root, map[string]kernfs.Inode{ - "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")), - "message_burst": fs.newInode(root, 0444, newStaticFile("10")), - "message_cost": fs.newInode(root, 0444, newStaticFile("5")), - "optmem_max": fs.newInode(root, 0444, newStaticFile("0")), - "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")), - "somaxconn": fs.newInode(root, 0444, newStaticFile("128")), - "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")), + "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")), + "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")), + "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), }), } } - return fs.newStaticDir(root, contents) + return fs.newStaticDir(ctx, root, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2582ababd..7ee6227a9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -77,6 +77,7 @@ var ( "gid_map": linux.DT_REG, "io": linux.DT_REG, "maps": linux.DT_REG, + "mem": linux.DT_REG, "mountinfo": linux.DT_REG, "mounts": linux.DT_REG, "net": linux.DT_DIR, diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index cf91ea36c..fda1fa942 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { +func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry { fs := mnt.Filesystem().Impl().(*filesystem) // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) + i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(&fs.Filesystem, i) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 906cd52cb..09043b572 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "dir_refs.go", package = "sys", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -28,6 +28,7 @@ go_library( "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 31a361029..b13f141a8 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -29,7 +29,7 @@ import ( func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode { k := &kcovInode{} - k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) return k } diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1ad679830..7d2147141 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -18,6 +18,7 @@ package sys import ( "bytes" "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -29,9 +30,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "sysfs" -const defaultSysDirMode = linux.FileMode(0755) +const ( + // Name is the default filesystem name. + Name = "sysfs" + defaultSysDirMode = linux.FileMode(0755) + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType implements vfs.FilesystemType. // @@ -62,28 +66,40 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + fs := &filesystem{ devMinor: devMinor, } + fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil), }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ + "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "cpu": cpuDir(ctx, fs, creds), }), }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), + "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), "kernel": kernelDir(ctx, fs, creds), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), + "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), }) var rootD kernfs.Dentry rootD.Init(&fs.Filesystem, root) @@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() children := map[string]kernfs.Inode{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), } for i := uint(0); i < maxCPUCores; i++ { - children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil) } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { @@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker var children map[string]kernfs.Inode if coverage.KcovAvailable() { children = map[string]kernfs.Inode{ - "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{ + "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ "kcov": fs.newKcovFile(ctx, creds), }), } } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } // Release implements vfs.FilesystemImpl.Release. @@ -140,9 +156,9 @@ type dir struct { locks vfs.FileLocks } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { d := &dir{} - d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) d.EnableLeakCheck() d.IncLinks(d.OrderedChildren.Populate(contents)) @@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { +func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) return c } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 5cd428d64..fe520b6fd 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -31,7 +31,7 @@ go_template_instance( out = "inode_refs.go", package = "tmpfs", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -48,6 +48,7 @@ go_library( "inode_refs.go", "named_pipe.go", "regular_file.go", + "save_restore.go", "socket_file.go", "symlink.go", "tmpfs.go", @@ -60,6 +61,7 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index ce4e3eda7..98680fde9 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -42,7 +42,7 @@ type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile + memFile *pgalloc.MemoryFile `state:"nosave"` // memoryUsageKind is the memory accounting category under which pages backing // this regularFile's contents are accounted. @@ -92,7 +92,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, + memFile: fs.mfp.MemoryFile(), memoryUsageKind: usage.Tmpfs, seals: linux.F_SEAL_SEAL, } diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go new file mode 100644 index 000000000..b27f75cc2 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +// afterLoad is called by stateify. +func (rf *regularFile) afterLoad() { + rf.memFile = rf.inode.fs.mfp.MemoryFile() +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index e2a0aac69..4ce859d57 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -61,8 +61,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile + // mfp is used to allocate memory that stores regular file contents. mfp is + // immutable. + mfp pgalloc.MemoryFileProvider // clock is a realtime clock used to set timestamps in file operations. clock time.Clock @@ -106,8 +107,8 @@ type FilesystemOpts struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { + mfp := pgalloc.MemoryFileProviderFromContext(ctx) + if mfp == nil { panic("MemoryFileProviderFromContext returned nil") } @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } clock := time.RealtimeClockFromContext(ctx) fs := filesystem{ - memFile: memFileProvider.MemoryFile(), + mfp: mfp, clock: clock, devMinor: devMinor, } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 0ca750281..e265be0ee 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -6,6 +6,7 @@ go_library( name = "verity", srcs = [ "filesystem.go", + "save_restore.go", "verity.go", ], visibility = ["//pkg/sentry:internal"], @@ -15,6 +16,7 @@ go_library( "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", @@ -38,10 +40,12 @@ go_test( "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 03da505e1..2f6050cfd 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // Verify returns with success. var buf bytes.Buffer if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize()), + ReadSize: int64(merkletree.DigestSize(linux.FS_VERITY_HASH_ALG_SHA256)), Expected: parent.hash, DataAndTreeInSameFile: true, }); err != nil && err != io.EOF { - return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. @@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err @@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat }) if err == syserror.ENODATA { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat var buf bytes.Buffer params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - ReadOffset: 0, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, Expected: d.hash, @@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID + d.size = uint32(size) return nil } // Preconditions: fs.renameMu must be locked. d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { + // If verity is enabled on child, we should check again whether + // the file and the corresponding Merkle tree are as expected, + // in order to catch deletion/renaming after the last time it's + // accessed. + if child.verityEnabled() { + vfsObj := fs.vfsfs.VirtualFilesystem() + // Get the path to the child dentry. This is only used + // to provide path information in failure case. + path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + // The file was previously accessed. If the + // file does not exist now, it indicates an + // unexpected modification to the file system. + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + } + if err != nil { + return nil, err + } + defer childVD.DecRef(ctx) + + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + // The Merkle tree file was previous accessed. If it + // does not exist now, it indicates an unexpected + // modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + } + if err != nil { + return nil, err + } + + defer childMerkleVD.DecRef(ctx) + } + // If enabling verification on files/directories is not allowed // during runtime, all cached children are already verified. If // runtime enable is allowed and the parent directory is @@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childFilename := fspath.Parse(name) - childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: childFilename, - }, &vfs.GetDentryOptions{}) - + childVD, childErr := parent.getLowerAt(ctx, vfsObj, name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childErr != nil && childErr != syserror.ENOENT { @@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) } - childMerkleFilename := merklePrefix + name - childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - + childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { @@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: parent.lowerVD, Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), + Path: fspath.Parse(merklePrefix + name), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, Mode: 0644, @@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } childMerkleFD.DecRef(ctx) - childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { return nil, err } @@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { // Both the child and the corresponding Merkle tree are missing. @@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // TODO(b/167752508): Investigate possible ways to differentiate // cases that both files are deleted from cases that they never // exist in the file system. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go new file mode 100644 index 000000000..4a161163c --- /dev/null +++ b/pkg/sentry/fsimpl/verity/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d, "verity.dentry") + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 8dc9e26bc..92ca6ca6b 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -23,6 +23,7 @@ package verity import ( "fmt" + "math" "strconv" "sync/atomic" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -153,10 +155,10 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. -func alertIntegrityViolation(err error, msg string) error { +// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +func alertIntegrityViolation(msg string) error { if noCrashOnVerificationFailure { - return err + return syserror.EIO } panic(msg) } @@ -236,7 +238,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -289,11 +291,12 @@ type dentry struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // mode, uid and gid are the file mode, owner, and group of the file in - // the underlying file system. + // mode, uid, gid and size are the file mode, owner, group, and size of + // the file in the underlying file system. mode uint32 uid uint32 gid uint32 + size uint32 // parent is the dentry corresponding to this dentry's parent directory. // name is this dentry's name in parent. If this dentry is a filesystem @@ -331,6 +334,9 @@ func (fs *filesystem) newDentry() *dentry { fs: fs, } d.vfsd.Init(d) + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(d, "verity.dentry") + } return d } @@ -393,6 +399,9 @@ func (d *dentry) destroyLocked(ctx context.Context) { if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(d, "verity.dentry") + } if d.lowerMerkleVD.Ok() { d.lowerMerkleVD.DecRef(ctx) @@ -412,6 +421,11 @@ func (d *dentry) destroyLocked(ctx context.Context) { } } +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) { //TODO(b/159261227): Implement InotifyWithParent. @@ -448,6 +462,16 @@ func (d *dentry) verityEnabled() bool { return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// getLowerAt returns the dentry in the underlying file system, which is +// represented by filename relative to d. +func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { + return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(filename), + }, &vfs.GetDentryOptions{}) +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -489,6 +513,10 @@ type fileDescription struct { // directory that contains the current file/directory. This is only used // if allowRuntimeEnable is set to true. parentMerkleWriter *vfs.FileDescription + + // off is the file offset. off is protected by mu. + mu sync.Mutex `state:"nosave"` + off int64 } // Release implements vfs.FileDescriptionImpl.Release. @@ -524,6 +552,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + n := int64(0) + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + n = fd.off + case linux.SEEK_END: + n = int64(fd.d.size) + default: + return 0, syserror.EINVAL + } + if offset > math.MaxInt64-n { + return 0, syserror.EINVAL + } + offset += n + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + // generateMerkle generates a Merkle tree file for fd. If fd points to a file // /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash // of the generated Merkle tree and the data size is returned. If fd points to @@ -546,6 +600,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -611,7 +667,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkle(ctx) @@ -657,6 +713,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // measureVerity returns the hash of fd, saved in verityDigest. func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } var metadata linux.DigestMetadata // If allowRuntimeEnable is true, an empty fd.d.hash indicates that @@ -667,7 +726,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found") + return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -702,6 +761,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag } t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -722,6 +784,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Implement Read with PRead by setting offset. + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in @@ -742,7 +814,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -752,7 +824,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := vfs.FileReadWriteSeeker{ @@ -766,25 +838,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, DataAndTreeInSameFile: false, }) if err != nil { - return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index e301d35f5..c647cbfd3 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -25,10 +25,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -41,11 +43,18 @@ const maxDataSize = 100000 // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { +func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("testutil.Boot: %v", err) + } + + ctx := k.SupervisorContext() + rand.Seed(time.Now().UnixNano()) vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -67,16 +76,26 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v }, }) if err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) } root := mntns.Root() root.IncRef() + + // Use lowerRoot in the task as we modify the lower file system + // directly in many tests. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) + if err != nil { + t.Fatalf("testutil.CreateTask: %v", err) + } + t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, nil + return vfsObj, root, task, nil } // newFileFD creates a new file in the verity mount, and returns the FD. The FD @@ -145,8 +164,7 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -180,11 +198,10 @@ func TestOpen(t *testing.T) { } } -// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file -// succeeds after enabling verity for it. -func TestReadUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) +// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity +// file succeeds after enabling verity for it. +func TestPReadUnmodifiedFileSucceeds(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -213,11 +230,42 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedFileSucceeds(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -248,10 +296,10 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } } -// TestModifiedFileFails ensures that read from a modified verity file fails. -func TestModifiedFileFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) +// TestPReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestPReadModifiedFileFails(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -289,15 +337,59 @@ func TestModifiedFileFails(t *testing.T) { // Confirm that read from the modified file fails. buf := make([]byte, size) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified file") + t.Fatalf("fd.PRead succeeded, expected failure") + } +} + +// TestReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestReadModifiedFileFails(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -350,8 +442,7 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -428,8 +519,7 @@ func TestModifiedParentMerkleFails(t *testing.T) { // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -455,8 +545,7 @@ func TestUnmodifiedStatSucceeds(t *testing.T) { // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. func TestModifiedStatFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) + vfsObj, root, ctx, err := newVerityRoot(t) if err != nil { t.Fatalf("newVerityRoot: %v", err) } @@ -489,3 +578,123 @@ func TestModifiedStatFails(t *testing.T) { t.Errorf("fd.Stat succeeded when it should fail") } } + +// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed +// verity enabled file or the corresponding Merkle tree file fails with the +// verify error. +func TestOpenDeletedFileFails(t *testing.T) { + testCases := []struct { + // Tests removing files is remove is true. Otherwise tests + // renaming files. + remove bool + // The original file is removed/renamed if changeFile is true. + changeFile bool + // The Merkle tree file is removed/renamed if changeMerkleFile + // is true. + changeMerkleFile bool + }{ + { + remove: true, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: true, + changeFile: false, + changeMerkleFile: true, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + } + for _, tc := range testCases { + t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { + vfsObj, root, ctx, err := newVerityRoot(t) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.remove { + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } else { + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + }) + } +} diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, addr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c0de72eef..90dd4a047 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -79,7 +79,7 @@ go_template_instance( out = "fd_table_refs.go", package = "kernel", prefix = "FDTable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FDTable", }, @@ -90,7 +90,7 @@ go_template_instance( out = "fs_context_refs.go", package = "kernel", prefix = "FSContext", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FSContext", }, @@ -101,7 +101,7 @@ go_template_instance( out = "ipc_namespace_refs.go", package = "kernel", prefix = "IPCNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "IPCNamespace", }, @@ -112,7 +112,7 @@ go_template_instance( out = "process_group_refs.go", package = "kernel", prefix = "ProcessGroup", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ProcessGroup", }, @@ -123,7 +123,7 @@ go_template_instance( out = "session_refs.go", package = "kernel", prefix = "Session", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Session", }, @@ -229,7 +229,7 @@ go_library( "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 1b9721534..0ddbe5ff6 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs_vfs2" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) @@ -27,7 +27,7 @@ import ( // +stateify savable type abstractEndpoint struct { ep transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter name string ns *AbstractSocketNamespace } @@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { // its backing socket. type boundEndpoint struct { transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. @@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() @@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran // Remove removes the specified socket at name from the abstract socket // namespace, if it has not yet been replaced. -func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { +func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ec7344cd..7aba31587 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() - f.init() // Initialize table. + f.initNoLeakCheck() // Initialize table. f.used = 0 for fd, d := range m { if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { @@ -240,6 +240,10 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() + vd := fileVFS2.VirtualDentry() + if vd.Dentry() == nil { + panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2)) + } name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index da79e6627..3476551f3 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -31,14 +31,21 @@ type descriptorTable struct { slice unsafe.Pointer `state:".(map[int32]*descriptor)"` } -// init initializes the table. +// initNoLeakCheck initializes the table without enabling leak checking. // -// TODO(gvisor.dev/1486): Enable leak check for FDTable. -func (f *FDTable) init() { +// This is used when loading an FDTable after S/R, during which the ref count +// object itself will enable leak checking if necessary. +func (f *FDTable) initNoLeakCheck() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) } +// init initializes the table with leak checking. +func (f *FDTable) init() { + f.initNoLeakCheck() + f.EnableLeakCheck() +} + // get gets a file entry. // // The boolean indicates whether this was in range. diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index d46d1e1c1..41fb2a784 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext { f.root.IncRef() } - return &FSContext{ + ctx := &FSContext{ cwd: f.cwd, root: f.root, cwdVFS2: f.cwdVFS2, rootVFS2: f.rootVFS2, umask: f.umask, } + ctx.EnableLeakCheck() + return ctx } // WorkingDirectory returns the current working directory. @@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() - f.cwd.IncRef() + if f.cwd != nil { + f.cwd.IncRef() + } return f.cwd } // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.cwdVFS2.IncRef() + if f.cwdVFS2.Ok() { + f.cwdVFS2.IncRef() + } return f.cwdVFS2 } @@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.rootVFS2.IncRef() + if f.rootVFS2.Ok() { + f.rootVFS2.IncRef() + } return f.rootVFS2 } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 3f34ee0db..b87e40dd1 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } -// DecRef implements refs_vfs2.RefCounter.DecRef. +// DecRef implements refsvfs2.RefCounter.DecRef. func (i *IPCNamespace) DecRef(ctx context.Context) { i.IPCNamespaceRefs.DecRef(func() { i.shms.Release(ctx) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0eb2bf7bd..9b2be44d4 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -430,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w wire.Writer) error { +func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { saveStart := time.Now() - ctx := k.SupervisorContext() // Do not allow other Kernel methods to affect it while it's being saved. k.extMu.Lock() @@ -446,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error { k.mf.StartEvictions() k.mf.WaitForEvictions() - // Flush write operations on open files so data reaches backing storage. - // This must come after MemoryFile eviction since eviction may cause file - // writes. - if err := k.tasks.flushWritesToFiles(ctx); err != nil { - return err - } + if VFS2Enabled { + // Discard unsavable mappings, such as those for host file descriptors. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) + // Prepare filesystems for saving. This must be done after + // invalidateUnsavableMappings(), since dropping memory mappings may + // affect filesystem state (e.g. page cache reference counts). + if err := k.vfs.PrepareSave(ctx); err != nil { + return err + } + } else { + // Flush cached file writes to backing storage. This must come after + // MemoryFile eviction since eviction may cause file writes. + if err := k.flushWritesToFiles(ctx); err != nil { + return err + } - // Clear the dirent cache before saving because Dirents must be Loaded in a - // particular order (parents before children), and Loading dirents from a cache - // breaks that order. - if err := k.flushMountSourceRefs(ctx); err != nil { - return err - } + // Remove all epoll waiter objects from underlying wait queues. + // NOTE: for programs to resume execution in future snapshot scenarios, + // we will need to re-establish these waiter objects after saving. + k.tasks.unregisterEpollWaiters(ctx) - // Ensure that all inode and mount release operations have completed. - fs.AsyncBarrier() + // Clear the dirent cache before saving because Dirents must be Loaded in a + // particular order (parents before children), and Loading dirents from a cache + // breaks that order. + if err := k.flushMountSourceRefs(ctx); err != nil { + return err + } - // Once all fs work has completed (flushed references have all been released), - // reset mount mappings. This allows individual mounts to save how inodes map - // to filesystem resources. Without this, fs.Inodes cannot be restored. - fs.SaveInodeMappings() + // Ensure that all inode and mount release operations have completed. + fs.AsyncBarrier() - // Discard unsavable mappings, such as those for host file descriptors. - // This must be done after waiting for "asynchronous fs work", which - // includes async I/O that may touch application memory. - if err := k.invalidateUnsavableMappings(ctx); err != nil { - return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + // Once all fs work has completed (flushed references have all been released), + // reset mount mappings. This allows individual mounts to save how inodes map + // to filesystem resources. Without this, fs.Inodes cannot be restored. + fs.SaveInodeMappings() + + // Discard unsavable mappings, such as those for host file descriptors. + // This must be done after waiting for "asynchronous fs work", which + // includes async I/O that may touch application memory. + // + // TODO(gvisor.dev/issue/1624): This rationale is believed to be + // obsolete since AIO callbacks are now waited-for by Kernel.Pause(), + // but this order is conservatively retained for VFS1. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } } // Save the CPUID FeatureSet before the rest of the kernel so we can @@ -486,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { + if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - stats, err := state.Save(k.SupervisorContext(), w, k) + stats, err := state.Save(ctx, w, k) if err != nil { return err } @@ -502,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { + if err := k.mf.SaveTo(ctx, w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -514,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. +// +// Preconditions: !VFS2Enabled. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { - if VFS2Enabled { - return nil // Not relevant. - } - // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -561,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi return err } -func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - - return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { +// Preconditions: !VFS2Enabled. +func (k *Kernel) flushWritesToFiles(ctx context.Context) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -589,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: The kernel must be paused. -func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { - invalidated := make(map[*mm.MemoryManager]struct{}) - k.tasks.mu.RLock() - defer k.tasks.mu.RUnlock() - for t := range k.tasks.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { - return err - } - invalidated[mm] = struct{}{} - } - } - // I really wish we just had a sync.Map of all MMs... - if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { - return err - } - } - } - return nil -} - +// Preconditions: !VFS2Enabled. func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return - } - ts.mu.RLock() defer ts.mu.RUnlock() @@ -644,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { } } +// Preconditions: The kernel must be paused. +func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { + invalidated := make(map[*mm.MemoryManager]struct{}) + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + for t := range k.tasks.Root.tids { + // We can skip locking Task.mu here since the kernel is paused. + if mm := t.tc.MemoryManager; mm != nil { + if _, ok := invalidated[mm]; !ok { + if err := mm.InvalidateUnsavable(ctx); err != nil { + return err + } + invalidated[mm] = struct{}{} + } + } + // I really wish we just had a sync.Map of all MMs... + if r, ok := t.runState.(*runSyscallAfterExecStop); ok { + if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { + return err + } + } + } + return nil +} + // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -656,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { + if _, err := state.Load(ctx, r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -671,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the kernel state. kernelStart := time.Now() - stats, err := state.Load(k.SupervisorContext(), r, k) + stats, err := state.Load(ctx, r, k) if err != nil { return err } @@ -684,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { + if err := k.mf.LoadFrom(ctx, r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -696,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock net.Resume() } - // Ensure that all pending asynchronous work is complete: - // - namedpipe opening - // - inode file opening - if err := fs.AsyncErrorBarrier(); err != nil { - return err + if VFS2Enabled { + if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil { + return err + } + } else { + // Ensure that all pending asynchronous work is complete: + // - namedpipe opening + // - inode file opening + if err := fs.AsyncErrorBarrier(); err != nil { + return err + } } tcpip.AsyncLoading.Wait() diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 1a152142b..d96bf253b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -33,6 +33,8 @@ import ( // VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should // not be copied. +// +// +stateify savable type VFSPipe struct { // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements // non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to // other FileDescriptions for splice(2) and tee(2). +// +// +stateify savable type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c00fa1138..c39ecfb8f 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -283,6 +283,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File return nil } +// GetStat extracts semid_ds information from the set. +func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + ds := &linux.SemidDS{ + SemPerm: linux.IPCPerm{ + Key: uint32(s.key), + UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)), + Mode: uint16(s.perms.LinuxMode()), + Seq: 0, // IPC sequence not supported. + }, + SemOTime: s.opTime.TimeT(), + SemCTime: s.changeTime.TimeT(), + SemNSems: uint64(s.Size()), + } + return ds, nil +} + // SetVal overrides a semaphore value, waking up waiters as needed. func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error { if val < 0 || val > valueMax { @@ -320,7 +347,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti } for _, val := range vals { - if val < 0 || val > valueMax { + if val > valueMax { return syserror.ERANGE } } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index f8a382fd8..80a592c8f 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "shm_refs.go", package = "shm", prefix = "Shm", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Shm", }, @@ -27,7 +27,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b4a47ccca..6dbeccfe2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -78,7 +78,7 @@ go_template_instance( out = "aio_mappable_refs.go", package = "mm", prefix = "aioMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "aioMappable", }, @@ -89,7 +89,7 @@ go_template_instance( out = "special_mappable_refs.go", package = "mm", prefix = "SpecialMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SpecialMappable", }, @@ -127,6 +127,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safecopy", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a163f956d..84992c06d 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -235,8 +235,9 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return c.fault(int32(syscall.SIGSEGV), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef, - ring0.El1Sync_undef: + case ring0.El0Sync_undef: + return c.fault(int32(syscall.SIGILL), info) + case ring0.El1Sync_undef: *info = arch.SignalInfo{ Signo: int32(syscall.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2370a9276..1079a024b 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -366,6 +366,19 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// EXCEPTION_WITH_ERROR is a common exception handler function. +#define EXCEPTION_WITH_ERROR(user, vector) \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + WORD $0xd538601a; \ //MRS FAR_EL1, R26 + MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ + MOVD $user, R3; \ + MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. + MOVD $vector, R3; \ + MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ + MRS ESR_EL1, R3; \ + MOVD R3, CPU_ERROR_CODE(RSV_REG); \ + B ·kernelExitToEl1(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -659,21 +672,7 @@ el0_svc: el0_da: el0_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $1, R3 - MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - MRS ESR_EL1, R3 - MOVD R3, CPU_ERROR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, PageFault) el0_fpsimd_acc: B ·Shutdown(SB) @@ -688,10 +687,7 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - MOVD $El0Sync_undef, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, El0Sync_undef) el0_dbg: B ·Shutdown(SB) diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index d9621968c..37d02948f 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -24,6 +24,8 @@ import ( ) // SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +// +// +stateify savable type SCMRightsVFS2 interface { transport.RightsControlMessage @@ -34,9 +36,11 @@ type SCMRightsVFS2 interface { Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) } -// RightsFiles represents a SCM_RIGHTS socket control message. A reference is -// maintained for each vfs.FileDescription and is release either when an FD is created or -// when the Release method is called. +// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference +// is maintained for each vfs.FileDescription and is release either when an FD +// is created or when the Release method is called. +// +// +stateify savable type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 163af329b..9a2cac40b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// +stateify savable type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &socketVFS2{ diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index faa61160e..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } if rawLine == "" { - return fmt.Errorf("Failed to get raw line") + return fmt.Errorf("failed to get raw line") } parts := strings.SplitN(rawLine, ":", 2) if len(parts) != 2 { - return fmt.Errorf("Failed to get prefix from: %q", rawLine) + return fmt.Errorf("failed to get prefix from: %q", rawLine) } sliceStat = toSlice(stat) fields := strings.Fields(strings.TrimSpace(parts[1])) if len(fields) != len(sliceStat) { - return fmt.Errorf("Failed to parse fields: %q", rawLine) + return fmt.Errorf("failed to parse fields: %q", rawLine) } if _, ok := stat.(*inet.StatSNMPTCP); ok { snmpTCP = true @@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) } if err != nil { - return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err) } } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 844acfede..352c51390 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.TCPProtocolNumber { - return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber) } return &TCPMatcher{ diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 63201201c..c88d8268d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber) } return &UDPMatcher{ diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index e8930f031..f061c5d62 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c84d8bd7c..22216158e 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -36,9 +36,9 @@ type commandKind int const ( kindNew commandKind = 0x0 - kindDel = 0x1 - kindGet = 0x2 - kindSet = 0x3 + kindDel commandKind = 0x1 + kindGet commandKind = 0x2 + kindSet commandKind = 0x3 ) func typeKind(typ uint16) commandKind { @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,8 +444,57 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + return nil +} + +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrInvalidArgument + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } + return nil } @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index c83b23242..461d524e5 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -37,6 +37,8 @@ import ( // to/from the kernel. // // SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 211f07947..86c634715 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1244,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ACCEPTCONN: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 4c6791fff..b0d9e4d9e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -35,6 +35,8 @@ import ( // SocketVFS2 encapsulates all the state needed to represent a network stack // endpoint in the kernel context. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu } mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &SocketVFS2{ diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..fa9ac9059 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. +func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. - NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cc7408698..cce0acc33 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "socket_refs.go", package = "unix", prefix = "socketOperations", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketOperations", }, @@ -19,7 +19,7 @@ go_template_instance( out = "socket_vfs2_refs.go", package = "unix", prefix = "socketVFS2", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketVFS2", }, @@ -43,6 +43,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 26c3a51b9..3ebbd28b0 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "queue_refs.go", package = "transport", prefix = "queue", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "queue", }, @@ -44,6 +44,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d6fc03520..b648273a4 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -32,6 +32,8 @@ import ( const initialLimit = 16 * 1024 // A RightsControlMessage is a control message containing FDs. +// +// +stateify savable type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. Clone() RightsControlMessage @@ -336,7 +338,7 @@ type Receiver interface { RecvMaxQueueSize() int64 // Release releases any resources owned by the Receiver. It should be - // called before droping all references to a Receiver. + // called before dropping all references to a Receiver. Release(ctx context.Context) } @@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds c := q.control.Clone() // Don't consume data since we are peeking. - copied, data, _ = vecCopy(data, q.buffer) + copied, _, _ = vecCopy(data, q.buffer) return copied, copied, c, false, q.addr, notify, nil } @@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds return copied, copied, c, cmTruncated, q.addr, notify, nil } +// Release implements Receiver.Release. +func (q *streamQueueReceiver) Release(ctx context.Context) { + q.queueReceiver.Release(ctx) + q.control.Release(ctx) +} + // A ConnectedEndpoint is an Endpoint that can be used to send Messages. type ConnectedEndpoint interface { // Passcred implements Endpoint.Passcred. @@ -619,7 +627,7 @@ type ConnectedEndpoint interface { SendMaxQueueSize() int64 // Release releases any resources owned by the ConnectedEndpoint. It should - // be called before droping all references to a ConnectedEndpoint. + // be called before dropping all references to a ConnectedEndpoint. Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to @@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.PasscredOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index a4a76d0a3..adad485a9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty }, } s.EnableLeakCheck() - return fs.NewFile(ctx, d, flags, &s) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 678355fb9..7a78444dc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) @@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.EnableLeakCheck() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 0ea4aab8b..563d60578 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -12,10 +12,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/time", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 245d2c5cf..167754537 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -19,10 +19,12 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -57,7 +59,7 @@ type SaveOpts struct { } // Save saves the system state. -func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { +func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() @@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { err = ErrStateFile{err} } else { // Save the kernel. - err = k.SaveTo(wc) + err = k.SaveTo(ctx, wc) // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync @@ -108,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n, clocks) + return k.LoadFrom(ctx, r, n, clocks, vfsOpts) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 9c9def7cd..36902d177 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 47dadb800..c2d4bf805 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -129,9 +129,17 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getPID(t, id, num) return uintptr(v), nil, err + case linux.IPC_STAT: + arg := args[3].Pointer() + ds, err := ipcStat(t, id) + if err == nil { + _, err = ds.CopyOut(t, arg) + } + + return 0, nil, err + case linux.IPC_INFO, linux.SEM_INFO, - linux.IPC_STAT, linux.SEM_STAT, linux.SEM_STAT_ANY, linux.GETNCNT, @@ -171,6 +179,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP return set.Change(t, creds, owner, perms) } +func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.GetStat(creds) +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index c855608db..440c9307c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "file_description_refs.go", package = "vfs", prefix = "FileDescription", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FileDescription", }, @@ -43,7 +43,7 @@ go_template_instance( out = "mount_namespace_refs.go", package = "vfs", prefix = "MountNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "MountNamespace", }, @@ -54,7 +54,7 @@ go_template_instance( out = "filesystem_refs.go", package = "vfs", prefix = "Filesystem", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Filesystem", }, @@ -87,6 +87,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "save_restore.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -99,6 +100,7 @@ go_library( "//pkg/gohacks", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8f36c3e3b..a98aac52b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -74,7 +74,7 @@ type epollInterestKey struct { // +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. - epoll *EpollInstance + epoll *EpollInstance `state:"wait"` // key is the file to which this epollInterest applies. key is immutable. key epollInterestKey diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 183957ad8..546e445aa 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 2d27d9d35..ba6e6ed49 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { return vfs.PrependPathAtVFSRootError{} } - if &d.vfsd == mnt.Root() { + if mnt != nil && &d.vfsd == mnt.Root() { return nil } if d.parent == nil { @@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath d = d.parent } } + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func DebugPathname(d *Dentry) string { + var b fspath.Builder + _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 55783d4eb..1ff202f2a 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. package vfs import ( diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 78f115bfa..d452d2cda 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,6 +107,9 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount if opts.ReadOnly { mnt.setReadOnlyLocked(true) } + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Register(mnt, "vfs.Mount") + } return mnt } @@ -489,26 +493,38 @@ func (mnt *Mount) IncRef() { // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - var vd VirtualDentry - if mnt.parent() != nil { - mnt.vfs.mountMu.Lock() - mnt.vfs.mounts.seq.BeginWrite() - vd = mnt.vfs.disconnectLocked(mnt) - mnt.vfs.mounts.seq.EndWrite() - mnt.vfs.mountMu.Unlock() - } - if mnt.root != nil { - mnt.root.DecRef(ctx) - } - mnt.fs.DecRef(ctx) - if vd.Ok() { - vd.DecRef(ctx) + r := atomic.AddInt64(&mnt.refs, -1) + if r&^math.MinInt64 == 0 { // mask out MSB + if refsvfs2.LeakCheckEnabled() { + refsvfs2.Unregister(mnt, "vfs.Mount") } + mnt.destroy(ctx) } } +func (mnt *Mount) destroy(ctx context.Context) { + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + if mnt.root != nil { + mnt.root.DecRef(ctx) + } + mnt.fs.DecRef(ctx) + if vd.Ok() { + vd.DecRef(ctx) + } +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (mnt *Mount) LeakMessage() string { + return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs)) +} + // DecRef decrements mntns' reference count. func (mntns *MountNamespace) DecRef(ctx context.Context) { vfs := mntns.root.fs.VirtualFilesystem() diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b7d122d22..cb48c37a1 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -98,7 +98,6 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 - // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -212,6 +211,26 @@ loop: } } +// Range calls f on each Mount in mt. If f returns false, Range stops iteration +// and returns immediately. +func (mt *mountTable) Range(f func(*Mount) bool) { + tcap := uintptr(1) << (mt.size & mtSizeOrderMask) + slotPtr := mt.slots + last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes)) + for { + slot := (*mountSlot)(slotPtr) + if slot.value != nil { + if !f((*Mount)(slot.value)) { + return + } + } + if slotPtr == last { + return + } + slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes) + } +} + // Insert inserts the given mount into mt. // // Preconditions: mt must not already contain a Mount with the same mount point diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go new file mode 100644 index 000000000..46e50d55d --- /dev/null +++ b/pkg/sentry/vfs/save_restore.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// FilesystemImplSaveRestoreExtension is an optional extension to +// FilesystemImpl. +type FilesystemImplSaveRestoreExtension interface { + // PrepareSave prepares this filesystem for serialization. + PrepareSave(ctx context.Context) error + + // CompleteRestore completes restoration from checkpoint for this + // filesystem after deserialization. + CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error +} + +// PrepareSave prepares all filesystems for serialization. +func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.PrepareSave(ctx); err != nil { + ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) + } + return nil +} + +// CompleteRestore completes restoration from checkpoint for all filesystems +// after deserialization. +func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.CompleteRestore(ctx, *opts); err != nil { + ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) + } + return nil +} + +// CompleteRestoreOptions contains options to +// VirtualFilesystem.CompleteRestore() and +// FilesystemImplSaveRestoreExtension.CompleteRestore(). +type CompleteRestoreOptions struct { + // If ValidateFileSizes is true, filesystem implementations backed by + // remote filesystems should verify that file sizes have not changed + // between checkpoint and restore. + ValidateFileSizes bool + + // If ValidateFileModificationTimestamps is true, filesystem + // implementations backed by remote filesystems should validate that file + // mtimes have not changed between checkpoint and restore. + ValidateFileModificationTimestamps bool +} + +// saveMounts is called by stateify. +func (vfs *VirtualFilesystem) saveMounts() []*Mount { + if atomic.LoadPointer(&vfs.mounts.slots) == nil { + // vfs.Init() was never called. + return nil + } + var mounts []*Mount + vfs.mounts.Range(func(mount *Mount) bool { + mounts = append(mounts, mount) + return true + }) + return mounts +} + +// loadMounts is called by stateify. +func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { + if mounts == nil { + return + } + vfs.mounts.Init() + for _, mount := range mounts { + vfs.mounts.Insert(mount) + } +} + +func (mnt *Mount) afterLoad() { + if refsvfs2.LeakCheckEnabled() && atomic.LoadInt64(&mnt.refs) != 0 { + refsvfs2.Register(mnt, "vfs.Mount") + } +} + +// afterLoad is called by stateify. +func (epi *epollInterest) afterLoad() { + // Mark all epollInterests as ready after restore so that the next call to + // EpollInstance.ReadEvents() rechecks their readiness. + epi.Callback(nil) +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 38d2701d2..48d6252f7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -71,7 +71,7 @@ type VirtualFilesystem struct { // points. // // mounts is analogous to Linux's mount_hashtable. - mounts mountTable + mounts mountTable `state:".([]*Mount)"` // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. @@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre // SyncAllFilesystems has the semantics of Linux's sync(2). func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + var retErr error + for fs := range vfs.getFilesystems() { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef(ctx) + } + return retErr +} + +func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} { fss := make(map[*Filesystem]struct{}) vfs.filesystemsMu.Lock() + defer vfs.filesystemsMu.Unlock() for fs := range vfs.filesystems { if !fs.TryIncRef() { continue } fss[fs] = struct{}{} } - vfs.filesystemsMu.Unlock() - var retErr error - for fs := range fss { - if err := fs.impl.Sync(ctx); err != nil && retErr == nil { - retErr = err - } - fs.DecRef(ctx) - } - return retErr + return fss } // MkdirAllAt recursively creates non-existent directories on the given path diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 089b3bbef..92c51879b 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "pending_list", - out = "pending_list.go", - package = "state", - prefix = "pending", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "pendingMapper", - "Linker": "*pendingEntry", - }, -) - -go_template_instance( name = "deferred_list", out = "deferred_list.go", package = "state", @@ -83,7 +70,6 @@ go_library( "deferred_list.go", "encode.go", "encode_unsafe.go", - "pending_list.go", "state.go", "state_norace.go", "state_race.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 89467ca8e..e519ddeca 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -21,6 +21,7 @@ import ( "math" "reflect" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be -// directly assignable. See unsafePointerTo. +// directly assignable. See decode_unsafe.go. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { @@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) - if v.IsValid() { - obj.Set(unsafePointerTo(v)) - } + obj.Set(reflectValueRWAddr(v)) case wire.Bool: obj.SetBool(bool(x)) case wire.Int: @@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) - obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: @@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) { ds.pending.PushBack(rootOds) // Read the number of objects. - lastID, object, err := ReadHeader(ds.r) + numObjects, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } @@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) { var ( encoded wire.Object ods *objectDecodeState - id = objectID(1) + id objectID tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // - // Note that the structure of this decoding loop should match - // the raw decoding loop in printer.go. - for id <= objectID(lastID) { - // Unmarshal the object. + // Note that the structure of this decoding loop should match the raw + // decoding loop in state/pretty/pretty.printer.printStream(). + for i := uint64(0); i < numObjects; { + // Unmarshal either a type object or object ID. encoded = wire.Load(ds.r) - - // Is this a type object? Handle inline. - if wt, ok := encoded.(*wire.Type); ok { - ds.types.Register(wt) + switch we := encoded.(type) { + case *wire.Type: + ds.types.Register(we) tid++ encoded = nil continue + case wire.Uint: + id = objectID(we) + i++ + // Unmarshal and resolve the actual object. + encoded = wire.Load(ds.r) + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } + // For error handling. + ods = nil + encoded = nil + default: + Failf("wanted type or object ID, got %#v", encoded) } - - // Actually resolve the object. - ods = ds.lookup(id) - if ods != nil { - // Decode the object. - ds.decodeObject(ods, ods.obj, encoded) - } else { - // If an object hasn't had interest registered - // previously or isn't yet valid, we deferred - // decoding until interest is registered. - ds.deferred[id] = encoded - } - - // For error handling. - ods = nil - encoded = nil - id++ } }); err != nil { // Include as much information as we can, taking into account @@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) { if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { - Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) + Failf("error decoding from %#v: %w", encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. + numDeferred := 0 for id, encoded := range ds.deferred { - // Shoud never happen, the graph was bogus. - Failf("still have deferred objects: one is ID %d, %#v", id, encoded) + numDeferred++ + if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 { + typ := ds.types.LookupType(typeID(s.TypeID)) + log.Warningf("unused deferred object: ID %d, type %v", id, typ) + } else { + log.Warningf("unused deferred object: ID %d, %#v", id, encoded) + } + } + if numDeferred != 0 { + Failf("still had %d deferred objects", numDeferred) } // Scan and fire all callbacks. We iterate over the list of incomplete diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go index d048f61a1..f1208e2a2 100644 --- a/pkg/state/decode_unsafe.go +++ b/pkg/state/decode_unsafe.go @@ -15,13 +15,62 @@ package state import ( + "fmt" "reflect" + "runtime" "unsafe" ) -// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on -// values representing unexported fields. This bypasses visibility, but not -// type safety. -func unsafePointerTo(obj reflect.Value) reflect.Value { +// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned +// reflect.Value is usable in assignments even if obj was obtained by the use +// of unexported struct fields. +// +// Preconditions: obj.CanAddr(). +func reflectValueRWAddr(obj reflect.Value) reflect.Value { return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) } + +// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the +// returned reflect.Value is usable in assignments even if obj was obtained by +// the use of unexported struct fields. +// +// Preconditions: +// * arr.Kind() == reflect.Array. +// * i, j, k >= 0. +// * i <= j <= k <= arr.Len(). +func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value { + if arr.Kind() != reflect.Array { + panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array)) + } + if i < 0 || j < 0 || k < 0 { + panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k)) + } + if i > j { + panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j)) + } + if j > k { + panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k)) + } + if k > arr.Len() { + panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len())) + } + + sliceTyp := reflect.SliceOf(arr.Type().Elem()) + if i == arr.Len() { + // By precondition, i == j == k == arr.Len(). + return reflect.MakeSlice(sliceTyp, 0, 0) + } + slh := reflect.SliceHeader{ + // reflect.Value.CanAddr() == false for arrays, so we need to get the + // address from the first element of the array. + Data: arr.Index(i).UnsafeAddr(), + Len: j - i, + Cap: k - i, + } + slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem() + // Before slobj is constructed, arr holds the only pointer-typed pointer to + // the array since reflect.SliceHeader.Data is a uintptr, so arr must be + // kept alive. + runtime.KeepAlive(arr) + return slobj +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 92fcad4e9..560e7c2a3 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -17,13 +17,14 @@ package state import ( "context" "reflect" + "sort" "gvisor.dev/gvisor/pkg/state/wire" ) // objectEncodeState the type and identity of an object occupying a memory // address range. This is the value type for addrSet, and the intrusive entry -// for the pending and deferred lists. +// for the deferred list. type objectEncodeState struct { // id is the assigned ID for this object. id objectID @@ -47,7 +48,6 @@ type objectEncodeState struct { // references may be updated directly and automatically. refs []*wire.Ref - pendingEntry deferredEntry } @@ -93,9 +93,15 @@ type encodeState struct { // serialized. pendingTypes []wire.Type - // pending is the list of objects to be serialized. Serialization does + // pending maps object IDs to objects to be serialized. Serialization does // not actually occur until the full object graph is computed. - pending pendingList + pending map[objectID]*objectEncodeState + + // encodedStructs maps reflect.Values representing structs to previous + // encodings of those structs. This is necessary to avoid duplicate calls + // to SaverLoader.StateSave() that may result in multiple calls to + // Sink.SaveValue() for a given field, resulting in object duplication. + encodedStructs map[reflect.Value]*wire.Struct // stats tracks time data. stats Stats @@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { // depending on this value knows there's nothing there. return } - if seg, _ := es.values.Find(addr); seg.Ok() { + seg, gap := es.values.Find(addr) + if seg.Ok() { // Ensure the map types match. existing := seg.Value() if existing.obj.Type() != obj.Type() { @@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { } // Record the map. + r := addrRange{addr, addr + 1} oes := &objectEncodeState{ id: es.nextID(), obj: obj, how: encodeMapAsValue, } - es.values.Add(addrRange{addr, addr + 1}, oes) - es.pending.PushBack(oes) + // Use Insert instead of InsertWithoutMergingUnchecked when race + // detection is enabled to get additional sanity-checking from Merge. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes es.deferred.PushBack(oes) // See above: no ref recording. @@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { obj: obj, } es.zeroValues[typ] = oes - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) } @@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { size = 1 // See above. } - // Calculate the container. end := addr + size r := addrRange{addr, end} - if seg, _ := es.values.Find(addr); seg.Ok() { + seg := es.values.LowerBoundSegment(addr) + var ( + oes *objectEncodeState + gap addrGapIterator + ) + + // Does at least one previously-registered object overlap this one? + if seg.Ok() && seg.Start() < end { existing := seg.Value() - switch { - case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): - // The object is a perfect match. Happy path. Avoid the - // traversal and just return directly. We don't need to - // encode the type information or any dots here. + + if seg.Range() == r && typ == existing.obj.Type() { + // This exact object is already registered. Avoid the traversal and + // just return directly. We don't need to encode the type + // information or any dots here. ref.Root = wire.Uint(existing.id) existing.refs = append(existing.refs, ref) return + } - case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): - // The previously registered object is larger than - // this, no need to update. But we expect some - // traversal below. + if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) { + // This object is contained within a previously-registered object. + // Perform traversal from the container to the new object. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr) + ref.Type = es.findType(existing.obj.Type()) + existing.refs = append(existing.refs, ref) + return + } - case seg.Start() == addr && seg.End() == end: - if !isSameSizeParent(obj, existing.obj.Type()) { - break // Needs traversal. + // This object contains one or more previously-registered objects. + // Remove them and update existing references to use the new one. + oes := &objectEncodeState{ + // Reuse the root ID of the first contained element. + id: existing.id, + obj: obj, + } + type elementEncodeState struct { + addr uintptr + typ reflect.Type + refs []*wire.Ref + } + var ( + elems []elementEncodeState + gap addrGapIterator + ) + for { + // Each contained object should be completely contained within + // this one. + if raceEnabled && !r.IsSupersetOf(seg.Range()) { + Failf("containing object %#v does not contain existing object %#v", obj, existing.obj) } - fallthrough // Needs update. - - case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): - // Update the object and redo the encoding. - old := existing.obj - existing.obj = obj + elems = append(elems, elementEncodeState{ + addr: seg.Start(), + typ: existing.obj.Type(), + refs: existing.refs, + }) + delete(es.pending, existing.id) es.deferred.Remove(existing) - es.deferred.PushBack(existing) - - // The previously registered object is superseded by - // this new object. We are guaranteed to not have any - // mergeable neighbours in this segment set. - if !raceEnabled { - seg.SetRangeUnchecked(r) - } else { - // Add extra paranoid. This will be statically - // removed at compile time unless a race build. - es.values.Remove(seg) - es.values.Add(r, existing) - seg = es.values.LowerBoundSegment(addr) + gap = es.values.Remove(seg) + seg = gap.NextSegment() + if !seg.Ok() || seg.Start() >= end { + break } - - // Compute the traversal required & update references. - dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) - wt := es.findType(obj.Type()) - for _, ref := range existing.refs { + existing = seg.Value() + } + wt := es.findType(typ) + for _, elem := range elems { + dots := traverse(typ, elem.typ, addr, elem.addr) + for _, ref := range elem.refs { + ref.Root = wire.Uint(oes.id) ref.Dots = append(ref.Dots, dots...) ref.Type = wt } - default: - // There is a non-sensical overlap. - Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + oes.refs = append(oes.refs, elem.refs...) } - - // Compute the new reference, record and return it. - ref.Root = wire.Uint(existing.id) - ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) - ref.Type = es.findType(obj.Type()) - existing.refs = append(existing.refs, ref) + // Finally register the new containing object. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) return } - // The only remaining case is a pointer value that doesn't overlap with - // any registered addresses. Create a new entry for it, and start - // tracking the first reference we just created. - oes := &objectEncodeState{ + // No existing object overlaps this one. Register a new object. + oes = &objectEncodeState{ id: es.nextID(), obj: obj, } + if seg.Ok() { + gap = seg.PrevGap() + } else { + gap = es.values.LastGap() + } if !raceEnabled { - es.values.AddWithoutMerging(r, oes) + es.values.InsertWithoutMergingUnchecked(gap, r, oes) } else { - // Merges should never happen. This is just enabled extra - // sanity checks because the Merge function below will panic. - es.values.Add(r, oes) + es.values.Insert(gap, r, oes) } - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) ref.Root = wire.Uint(oes.id) oes.refs = append(oes.refs, ref) @@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) { // encodeStruct encodes a composite object. func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + if s, ok := es.encodedStructs[obj]; ok { + *dest = s + return + } + s := &wire.Struct{} + *dest = s + es.encodedStructs[obj] = s + // Ensure that the obj is addressable. There are two cases when it is // not. First, is when this is dispatched via SaveValue. Second, when // this is a map key as a struct. Either way, we need to make a copy to @@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { obj = localObj.Elem() } - // Prepare the value. - s := &wire.Struct{} - *dest = s - // Look the type up in the database. te, ok := es.types.Lookup(obj.Type()) if te == nil { @@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) { Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) } - // Check that items are pending. - if es.pending.Front() == nil { + // Check that we have objects to serialize. + if len(es.pending) == 0 { Failf("pending is empty?") } - // Write the header with the number of objects. Note that there is no - // way that es.lastID could conflict with objectID, which would - // indicate that an impossibly large encoding. - if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + // Write the header with the number of objects. + if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil { Failf("error writing header: %w", err) } // Serialize all pending types and pending objects. Note that we don't // bother removing from this list as we walk it because that just // wastes time. It will not change after this point. - var id objectID if err := safely(func() { for _, wt := range es.pendingTypes { // Encode the type. wire.Save(es.w, &wt) } - for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { - id++ // First object is 1. - if oes.id != id { - Failf("expected id %d, got %d", id, oes.id) - } - - // Marshall the object. + // Emit objects in ID order. + ids := make([]objectID, 0, len(es.pending)) + for id := range es.pending { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + for _, id := range ids { + // Encode the id. + wire.Save(es.w, wire.Uint(id)) + // Marshal the object. + oes := es.pending[id] wire.Save(es.w, oes.encoded) } }); err != nil { // Include the object and the error. Failf("error serializing object %#v: %w", oes.encoded, err) } - - // Check what we wrote. - if id != es.lastID { - Failf("expected %d objects, wrote %d", es.lastID, id) - } } // objectFlag indicates that the length is a # of objects, rather than a raw @@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error { }) } -// pendingMapper is for the pending list. -type pendingMapper struct{} - -func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // deferredMapper is for the deferred list. type deferredMapper struct{} diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index 887f453a9..c6e8bb31d 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { buf.WriteString(typ) buf.WriteString(")(") buf.WriteString(baseRef) + buf.WriteString(")") for _, component := range x.Dots { switch v := component.(type) { case *wire.FieldName: @@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) } } - buf.WriteString(")") fullRef = buf.String() } if p.html { @@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // Note that this loop must match the general structure of the // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. + type objectAndID struct { + id uint64 + obj wire.Object + } var ( tid uint64 = 1 - objects []wire.Object + objects []objectAndID ) - for oid := uint64(1); oid <= length; { - // Unmarshal the object. + for i := uint64(0); i < length; { + // Unmarshal either a type object or object ID. encoded := wire.Load(r) - - // Is this a type? - if typ, ok := encoded.(*wire.Type); ok { + switch we := encoded.(type) { + case *wire.Type: str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - p.typeSpecs[tag] = typ + p.typeSpecs[tag] = we if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) @@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { return err } tid++ - continue + case wire.Uint: + // Unmarshal the actual object. + objects = append(objects, objectAndID{ + id: uint64(we), + obj: wire.Load(r), + }) + i++ + default: + return fmt.Errorf("wanted type or object ID, got %#v", encoded) } - - // Otherwise, it is a node. - objects = append(objects, encoded) - oid++ } - for i, encoded := range objects { - // oid starts at 1. - oid := i + 1 + for _, objAndID := range objects { // Format the node. - str, _ := p.format(graph, 0, encoded) - tag := fmt.Sprintf("g%dr%d", graph, oid) + str, _ := p.format(graph, 0, objAndID.obj) + tag := fmt.Sprintf("g%dr%d", graph, objAndID.id) if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) diff --git a/pkg/state/state.go b/pkg/state/state.go index acb629969..6b8540f03 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error { func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. es := encodeState{ - ctx: ctx, - w: w, - types: makeTypeEncodeDatabase(), - zeroValues: make(map[reflect.Type]*objectEncodeState), + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), + pending: make(map[objectID]*objectEncodeState), + encodedStructs: make(map[reflect.Value]*wire.Struct), } // Perform the encoding. diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go index bd2c2b399..69143d194 100644 --- a/pkg/state/tests/struct.go +++ b/pkg/state/tests/struct.go @@ -54,12 +54,47 @@ type outerArray struct { } // +stateify savable +type outerSlice struct { + inner []inner +} + +// +stateify savable type inner struct { v int64 } // +stateify savable +type outerFieldValue struct { + inner innerFieldValue +} + +// +stateify savable +type innerFieldValue struct { + v int64 `state:".(*savedFieldValue)"` +} + +// +stateify savable +type savedFieldValue struct { + v int64 +} + +func (ifv *innerFieldValue) saveV() *savedFieldValue { + return &savedFieldValue{ifv.v} +} + +func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { + ifv.v = sfv.v +} + +// +stateify savable type system struct { v1 interface{} v2 interface{} } + +// +stateify savable +type system3 struct { + v1 interface{} + v2 interface{} + v3 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index de9d17aa7..c91c2c032 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -15,6 +15,7 @@ package tests import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/state" @@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) { } func TestEmbeddedPointers(t *testing.T) { - var ( - ofs outerSame - of1 outerFieldFirst - of2 outerFieldSecond - oa outerArray - ) + // Give each int64 a random value to prevent Go from using + // runtime.staticuint64s, which confounds tests for struct duplication. + magic := func() int64 { + for { + n := rand.Int63() + if n < 0 || n > 255 { + return n + } + } + } + + ofs := outerSame{inner{magic()}} + of1 := outerFieldFirst{inner{magic()}, magic()} + of2 := outerFieldSecond{magic(), inner{magic()}} + oa := outerArray{[2]inner{{magic()}, {magic()}}} + osl := outerSlice{oa.inner[:]} + ofv := outerFieldValue{innerFieldValue{magic()}} runTestCases(t, false, "embedded-pointers", []interface{}{ system{&ofs, &ofs.inner}, @@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) { system{&oa, &oa.inner[1]}, system{&oa.inner[0], &oa}, system{&oa.inner[1], &oa}, + system3{&oa, &oa.inner[0], &oa.inner[1]}, + system3{&oa, &oa.inner[1], &oa.inner[0]}, + system3{&oa.inner[0], &oa, &oa.inner[1]}, + system3{&oa.inner[1], &oa, &oa.inner[0]}, + system3{&oa.inner[0], &oa.inner[1], &oa}, + system3{&oa.inner[1], &oa.inner[0], &oa}, + system{&oa, &osl}, + system{&osl, &oa}, + system{&ofv, &ofv.inner}, + system{&ofv.inner, &ofv}, }) } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 12b061def..b196324c7 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -97,6 +97,9 @@ type testConnection struct { func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, err + } entry, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&entry, waiter.EventOut) @@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } // Give c.Read() a chance to block before closing the connection. @@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } c.SetDeadline(time.Now().Add(time.Minute)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 6f81b0164..530f2ae2f 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -205,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker { if !ok { t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) } - options := ip.Options() + options := []byte(ip.Options()) // cmp.Diff does not consider nil slices equal to empty slices, but we do. if len(want) == 0 && len(options) == 0 { return @@ -859,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker { } } +// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. +func ICMPv4Pointer(want uint8) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Pointer(); got != want { + t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) + } + } +} + // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. // This assumes that the payload exactly makes up the rest of the slice. func ICMPv4Checksum() TransportChecker { @@ -953,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { } } +// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific +// field. +func ICMPv6TypeSpecific(want uint32) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + if got := icmpv6.TypeSpecific(); got != want { + t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) + } + } +} + +// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. +func ICMPv6Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + payload := icmpv6.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 504408878..2f13dea6a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -99,7 +99,8 @@ const ( // ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 ) // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. @@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 4c6e4be64..961b77628 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,7 +39,6 @@ import ( // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Options | Padding | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// const ( versIHL = 0 tos = 1 @@ -93,7 +93,7 @@ type IPv4Fields struct { DstAddr tcpip.Address } -// IPv4 represents an ipv4 header stored in a byte array. +// IPv4 is an IPv4 header. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other @@ -106,10 +106,13 @@ const ( IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. + // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. // // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 @@ -130,7 +133,7 @@ const ( // IPv4ProtocolNumber is IPv4's network protocol number. IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 - // IPv4Version is the version of the ipv4 protocol. + // IPv4Version is the version of the IPv4 protocol. IPv4Version = 4 // IPv4AllSystems is the all systems IPv4 multicast address as per @@ -148,6 +151,13 @@ const ( // packet that every IPv4 capable host must be able to // process/reassemble. IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 ) // Flags that may be set in an IPv4 packet. @@ -191,14 +201,13 @@ func IPVersion(b []byte) int { // Internet Header Length is the length of the internet header in 32 // bit words, and thus points to the beginning of the data. Note that // the minimum value for a correct header is 5. -// const ( ipVersionShift = 4 ipIHLMask = 0x0f IPv4IHLStride = 4 ) -// HeaderLength returns the value of the "header length" field of the ipv4 +// HeaderLength returns the value of the "header length" field of the IPv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & ipIHLMask) * IPv4IHLStride @@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) { b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } -// ID returns the value of the identifier field of the ipv4 header. +// ID returns the value of the identifier field of the IPv4 header. func (b IPv4) ID() uint16 { return binary.BigEndian.Uint16(b[id:]) } -// Protocol returns the value of the protocol field of the ipv4 header. +// Protocol returns the value of the protocol field of the IPv4 header. func (b IPv4) Protocol() uint8 { return b[protocol] } -// Flags returns the "flags" field of the ipv4 header. +// Flags returns the "flags" field of the IPv4 header. func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } @@ -232,41 +241,44 @@ func (b IPv4) More() bool { return b.Flags()&IPv4FlagMoreFragments != 0 } -// TTL returns the "TTL" field of the ipv4 header. +// TTL returns the "TTL" field of the IPv4 header. func (b IPv4) TTL() uint8 { return b[ttl] } -// FragmentOffset returns the "fragment offset" field of the ipv4 header. +// FragmentOffset returns the "fragment offset" field of the IPv4 header. func (b IPv4) FragmentOffset() uint16 { return binary.BigEndian.Uint16(b[flagsFO:]) << 3 } -// TotalLength returns the "total length" field of the ipv4 header. +// TotalLength returns the "total length" field of the IPv4 header. func (b IPv4) TotalLength() uint16 { return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } -// Checksum returns the checksum field of the ipv4 header. +// Checksum returns the checksum field of the IPv4 header. func (b IPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[checksum:]) } -// SourceAddress returns the "source address" field of the ipv4 header. +// SourceAddress returns the "source address" field of the IPv4 header. func (b IPv4) SourceAddress() tcpip.Address { return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// DestinationAddress returns the "destination address" field of the IPv4 // header. func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// Options returns a a buffer holding the options. -func (b IPv4) Options() []byte { +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - return b[options:hdrLen:hdrLen] + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } -// PayloadLength returns the length of the payload portion of the ipv4 packet. +// PayloadLength returns the length of the payload portion of the IPv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) } -// TOS returns the "type of service" field of the ipv4 header. +// TOS returns the "type of service" field of the IPv4 header. func (b IPv4) TOS() (uint8, uint32) { return b[tos], 0 } -// SetTOS sets the "type of service" field of the ipv4 header. +// SetTOS sets the "type of service" field of the IPv4 header. func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } @@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) { b[ttl] = v } -// SetTotalLength sets the "total length" field of the ipv4 header. +// SetTotalLength sets the "total length" field of the IPv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. +// SetChecksum sets the checksum field of the IPv4 header. func (b IPv4) SetChecksum(v uint16) { binary.BigEndian.PutUint16(b[checksum:], v) } // SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. +// IPv4 header. func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) @@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) { binary.BigEndian.PutUint16(b[id:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. +// SetSourceAddress sets the "source address" field of the IPv4 header. func (b IPv4) SetSourceAddress(addr tcpip.Address) { copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) } -// SetDestinationAddress sets the "destination address" field of the ipv4 +// SetDestinationAddress sets the "destination address" field of the IPv4 // header. func (b IPv4) SetDestinationAddress(addr tcpip.Address) { copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) } -// CalculateChecksum calculates the checksum of the ipv4 header. +// CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { return Checksum(b[:b.HeaderLength()], 0) } -// Encode encodes all the fields of the ipv4 header. +// Encode encodes all the fields of the IPv4 header. func (b IPv4) Encode(i *IPv4Fields) { b.SetHeaderLength(i.IHL) b[tos] = i.TOS @@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) { copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) } -// EncodePartial updates the total length and checksum fields of ipv4 header, +// EncodePartial updates the total length and checksum fields of IPv4 header, // taking in the partial checksum, which is the checksum of the header without // the total length and checksum fields. It is useful in cases when similar // packets are produced. @@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool { } return addr[0] == 0x7f } + +// ========================= Options ========================== + +// An IPv4OptionType can hold the valuse for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// Potential errors when parsing generic IP options. +var ( + ErrIPv4OptZeroLength = errors.New("zero length IP option") + ErrIPv4OptDuplicate = errors.New("duplicate IP option") + ErrIPv4OptInvalid = errors.New("invalid IP option") + ErrIPv4OptMalformed = errors.New("malformed IP option") + ErrIPv4OptionTruncated = errors.New("truncated IP option") + ErrIPv4OptionAddress = errors.New("bad IP option address") +) + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return IPv4Options(i.newOptions[i.writePoint:]) +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// - An error which is non-nil if an error condition was encountered. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, true, ErrIPv4OptMalformed + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen == 0 { + i.ErrCursor++ + return nil, true, ErrIPv4OptZeroLength + } + + if optLen == 1 { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + + if optLen > uint8(len(i.options)) { + i.ErrCursor++ + return nil, true, ErrIPv4OptionTruncated + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. + switch optType { + case IPv4OptionTimestampType: + if optLen < IPv4OptionTimestampHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionTimestamp(optionBody) + return &retval, false, nil + + case IPv4OptionRecordRouteType: + if optLen < IPv4OptionRecordRouteHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionRecordRoute(optionBody) + return &retval, false, nil + } + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil +} + +// +// IP Timestamp option - RFC 791 page 22. +// +--------+--------+--------+--------+ +// |01000100| length | pointer|oflw|flg| +// +--------+--------+--------+--------+ +// | internet address | +// +--------+--------+--------+--------+ +// | timestamp | +// +--------+--------+--------+--------+ +// | ... | +// +// Type = 68 +// +// The Option Length is the number of octets in the option counting +// the type, length, pointer, and overflow/flag octets (maximum +// length 40). +// +// The Pointer is the number of octets from the beginning of this +// option to the end of timestamps plus one (i.e., it points to the +// octet beginning the space for next timestamp). The smallest +// legal value is 5. The timestamp area is full when the pointer +// is greater than the length. +// +// The Overflow (oflw) [4 bits] is the number of IP modules that +// cannot register timestamps due to lack of space. +// +// The Flag (flg) [4 bits] values are +// +// 0 -- time stamps only, stored in consecutive 32-bit words, +// +// 1 -- each timestamp is preceded with internet address of the +// registering entity, +// +// 3 -- the internet address fields are prespecified. An IP +// module only registers its timestamp if it matches its own +// address with the next specified internet address. +// +// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC. +// +// The Timestamp is a right-justified, 32-bit timestamp in +// milliseconds since midnight UT. If the time is not available in +// milliseconds or cannot be provided with respect to midnight UT +// then any time may be inserted as a timestamp provided the high +// order bit of the timestamp field is set to one to indicate the +// use of a non-standard value. + +// IPv4OptTSFlags sefines the values expected in the Timestamp +// option Flags field. +type IPv4OptTSFlags uint8 + +// +// Timestamp option specific related constants. +const ( + // IPv4OptionTimestampHdrLength is the length of the timestamp option header. + IPv4OptionTimestampHdrLength = 4 + + // IPv4OptionTimestampSize is the size of an IP timestamp. + IPv4OptionTimestampSize = 4 + + // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address. + IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize + + // IPv4OptionTimestampMaxSize is limited by space for options + IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize + + // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp + // is present. + IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0 + + // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and + // IP are present. + IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1 + + // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that + // predefined IP is present. + IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3 +) + +// ipv4TimestampTime provides the current time as specified in RFC 791. +func ipv4TimestampTime(clock tcpip.Clock) uint32 { + const millisecondsPerDay = 24 * 3600 * 1000 + const nanoPerMilli = 1000000 + return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) +} + +// IP Timestamp option fields. +const ( + // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field. + IPv4OptTSPointerOffset = 2 + + // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow + // fields, (each being 4 bits). + IPv4OptTSOFLWAndFLGOffset = 3 + // These constants define the sub byte fields of the Flag and OverFlow field. + ipv4OptionTimestampOverflowshift = 4 + ipv4OptionTimestampFlagsMask byte = 0x0f +) + +var _ IPv4Option = (*IPv4OptionTimestamp)(nil) + +// IPv4OptionTimestamp is a Timestamp option from RFC 791. +type IPv4OptionTimestamp []byte + +// Type implements IPv4Option.Type(). +func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType } + +// Size implements IPv4Option. +func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } + +// Contents implements IPv4Option. +func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } + +// Pointer returns the pointer field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Pointer() uint8 { + return (*ts)[IPv4OptTSPointerOffset] +} + +// Flags returns the flags field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags { + return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask) +} + +// Overflow returns the Overflow field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Overflow() uint8 { + return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift +} + +// IncOverflow increments the Overflow field in the IP Timestamp option. It +// returns the incremented value. If the return value is 0 then the field +// overflowed. +func (ts *IPv4OptionTimestamp) IncOverflow() uint8 { + (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift + return ts.Overflow() +} + +// UpdateTimestamp updates the fields of the next free timestamp slot. +func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) { + slot := (*ts)[ts.Pointer()-1:] + + switch ts.Flags() { + case IPv4OptionTimestampOnlyFlag: + binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize + case IPv4OptionTimestampWithIPFlag: + if n := copy(slot, addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + case IPv4OptionTimestampWithPredefinedIPFlag: + if tcpip.Address(slot[:IPv4AddressSize]) == addr { + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + } + } +} + +// RecordRoute option specific related constants. +// +// from RFC 791 page 20: +// Record Route +// +// +--------+--------+--------+---------//--------+ +// |00000111| length | pointer| route data | +// +--------+--------+--------+---------//--------+ +// Type=7 +// +// The record route option provides a means to record the route of +// an internet datagram. +// +// The option begins with the option type code. The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index c5d8a3456..09cb153b1 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -101,8 +101,10 @@ const ( // The address is ff02::2. IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, - // section 5. + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. IPv6MinimumMTU = 1280 // IPv6Loopback is the IPv6 Loopback address. diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 0243424f6..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "tun_endpoint_refs.go", package = "tun", prefix = "tunEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tunEndpoint", }, @@ -28,6 +28,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..cda6328a2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) - // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.EnableLeakCheck() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index b40dde96b..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -30,5 +30,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..a79379abb 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,7 @@ package arp import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -150,28 +151,36 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -199,6 +208,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -227,26 +237,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - NetProto: ProtocolNumber, - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) - - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -286,6 +314,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // Note, to make sure that the ARP endpoint receives ARP packets, the "arp" // address must be added to every NIC that should respond to ARP requests. See // ProtocolAddress for more details. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 626af975a..bf1292bb8 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,11 @@ func (t eventType) String() string { type eventInfo struct { eventType eventType nicID tcpip.NICID - addr tcpip.Address - linkAddr tcpip.LinkAddress - state stack.NeighborState + entry stack.NeighborEntry } func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) } // arpDispatcher implements NUDDispatcher to validate the dispatching of @@ -96,35 +95,29 @@ type arpDispatcher struct { var _ stack.NUDDispatcher = (*arpDispatcher)(nil) -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryChanged, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryRemoved, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } @@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { return fmt.Errorf("got invalid event (-got +want):\n%s", diff) } case <-ctx.Done(): @@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { wantEvent := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: test.senderAddr, - linkAddr: tcpip.LinkAddress(test.senderLinkAddr), - state: stack.Stale, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, } if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { t.Fatal(err) @@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) } - if got, want := neigh.LocalAddr, stackAddr; got != want { - t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) - } if got, want := neigh.State, stack.Stale; got != want { t.Errorf("got neighbor State = %s, want = %s", got, want) } @@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + tests := []struct { name string + nicAddr tcpip.Address + localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", remoteLinkAddr: remoteLinkAddr, - expectLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, }, { - name: "Multicast", + name: "Multicast with no local address available", remoteLinkAddr: "", - expectLinkAddr: header.EthernetBroadcastAddress, + expectedErr: tcpip.ErrNetworkUnreachable, }, } for _, test := range tests { - p := arp.NewProtocol(nil) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) - } + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) - } + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + } + + rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) + } + }) } } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index ed502a473..936601287 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. +// +// releaseCB is a callback that will run when the fragment reassembly of a +// packet is complete or cancelled. releaseCB take a a boolean argument which is +// true iff the reassembly is cancelled due to timeout. releaseCB should be +// passed only with the first fragment of a packet. If more than one releaseCB +// are passed for the same packet, only the first releaseCB will be saved for +// the packet and the succeeding ones will be dropped by running them +// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -171,6 +179,12 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } + if releaseCB != nil { + if !r.setCallback(releaseCB) { + // We got a duplicate callback. Release it immediately. + releaseCB(false /* timedOut */) + } + } f.mu.Unlock() res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) @@ -178,14 +192,14 @@ func (f *Fragmentation) Process( // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -195,14 +209,14 @@ func (f *Fragmentation) Process( if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + r.release(timedOut) // releaseCB may run. } // releaseReassemblersLocked releases already-expired reassemblers, then @@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() { break } // If the oldest reassembler has already expired, release it. - f.release(r) + f.release(r, true /* timedOut*/) } } // PacketFragmenter is the book-keeping struct for packet fragmentation. type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - innerMTU int - fragmentCount int - currentFragment int - fragmentOffset int + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int } // MakePacketFragmenter prepares the struct needed for packet fragmentation. // // pkt is the packet to be fragmented. // -// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can // have. // // reserve is the number of bytes that should be reserved for the headers in // each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be // repeated in each fragment. However we do not currently support any header // of that kind yet, so the following computation is valid for both IPv4 and @@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) fragmentableData.Append(pkt.Data) - fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - innerMTU: innerMTU, - fragmentCount: fragmentCount, + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), } } @@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index d3c7d7f92..5dcd10730 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) @@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) { for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) { func TestMemoryLimits(t *testing.T) { f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) { func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) got := f.size want := 1 @@ -377,7 +377,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) { tests := []struct { name string - innerMTU int + fragmentPayloadLen uint32 transportHeaderLen int payloadSize int wantFragments []fragmentInfo }{ { name: "Packet exactly fits in MTU", - innerMTU: 1280, + fragmentPayloadLen: 1280, transportHeaderLen: 0, payloadSize: 1280, wantFragments: []fragmentInfo{ @@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet exactly does not fit in MTU", - innerMTU: 1000, + fragmentPayloadLen: 1000, transportHeaderLen: 0, payloadSize: 1001, wantFragments: []fragmentInfo{ @@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a transport header", - innerMTU: 560, + fragmentPayloadLen: 560, transportHeaderLen: 40, payloadSize: 560, wantFragments: []fragmentInfo{ @@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a huge transport header", - innerMTU: 500, + fragmentPayloadLen: 500, transportHeaderLen: 1300, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) { originalPayload.AppendView(pkt.TransportHeader().View()) originalPayload.Append(pkt.Data) var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { fragPkt, offset, copied, more := pf.BuildNextFragment() wantFragment := test.wantFragments[i] @@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) { if more != wantFragment.more { t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) } - if got := fragPkt.Size(); got > test.innerMTU { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) } if got := fragPkt.AvailableHeaderBytes(); got != reserve { t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) @@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) { }) } } + +func TestReleaseCallback(t *testing.T) { + const ( + proto = 99 + ) + + var result int + var callbackReasonIsTimeout bool + cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + + tests := []struct { + name string + callbacks []func(bool) + timeout bool + wantResult int + wantCallbackReasonIsTimeout bool + }{ + { + name: "callback runs on release", + callbacks: []func(bool){cb1}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "first callback is nil", + callbacks: []func(bool){nil, cb2}, + timeout: false, + wantResult: 2, + wantCallbackReasonIsTimeout: false, + }, + { + name: "two callbacks - first one is set", + callbacks: []func(bool){cb1, cb2}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "callback runs on timeout", + callbacks: []func(bool){cb1}, + timeout: true, + wantResult: 1, + wantCallbackReasonIsTimeout: true, + }, + { + name: "no callbacks", + callbacks: []func(bool){nil}, + timeout: false, + wantResult: 0, + wantCallbackReasonIsTimeout: false, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result = 0 + callbackReasonIsTimeout = false + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + + for i, cb := range test.callbacks { + _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) + if err != nil { + t.Errorf("f.Process error = %s", err) + } + } + + r, ok := f.reassemblers[id] + if !ok { + t.Fatalf("Reassemberr not found") + } + f.release(r, test.timeout) + + if result != test.wantResult { + t.Errorf("got result = %d, want = %d", result, test.wantResult) + } + if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { + t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9bb051a30..c0cc0bde0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -41,6 +41,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 + callback func(bool) } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } + +func (r *reassembler) setCallback(c func(bool)) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.callback != nil { + return false + } + r.callback = c + return true +} + +func (r *reassembler) release(timedOut bool) { + r.mu.Lock() + callback := r.callback + r.callback = nil + r.mu.Unlock() + + if callback != nil { + callback(timedOut) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..fa2a70dc8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) { } } } + +func TestSetCallback(t *testing.T) { + result := 0 + reasonTimeout := false + + cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } + + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + if !r.setCallback(cb1) { + t.Errorf("setCallback failed") + } + if r.setCallback(cb2) { + t.Errorf("setCallback should fail if one is already set") + } + r.release(true) + if result != 1 { + t.Errorf("got result = %d, want = 1", result) + } + if !reasonTimeout { + t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f20b94d97..969579601 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -304,6 +304,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -690,6 +694,10 @@ func TestIPv4ReceiveControl(t *testing.T) { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. nic.testObject.protocol = 10 diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7fc12e229..6252614ec 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -29,6 +29,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..cf287446e 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -23,10 +24,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { @@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if len(iph) > header.IPv4MinimumSize { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := processIPOptions(r, iph.Options(), op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - headerChecksum := h.Checksum() - h.SetChecksum(0) - calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(headerChecksum) - if calculatedChecksum != headerChecksum { - // It's possible that a raw socket still expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4PacketsSent + if !r.Stack().AllowICMPMessage() { + sent.RateLimited.Increment() return } @@ -98,9 +144,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() - replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + // It's possible that a raw socket expects to receive this. e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + pkt = nil + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast @@ -139,7 +190,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // The fields we need to alter. // // We need to produce the entire packet in the data segment in order to - // use WriteHeaderIncludedPacket(). + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. replyIPHdr.SetSourceAddress(r.LocalAddress) replyIPHdr.SetDestinationAddress(r.RemoteAddress) replyIPHdr.SetTTL(r.DefaultTTL()) @@ -157,8 +209,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { }) replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // The checksum will be calculated so we don't need to do it here. - sent := stats.ICMP.V4PacketsSent if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return @@ -182,8 +232,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -234,6 +287,21 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -374,17 +442,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - switch reason.(type) { + var counter *tcpip.StatCounter + switch reason := reason.(type) { case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - counter := sent.DstUnreachable if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e7c58ae0a..4592984a5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -16,7 +16,9 @@ package ipv4 import ( + "errors" "fmt" + "math" "sync/atomic" "time" @@ -31,6 +33,8 @@ import ( ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. // As per RFC 791 section 3.2: // The current recommendation for the initial timer setting is 15 seconds. // This may be changed as experience with this protocol accumulates. @@ -38,7 +42,7 @@ const ( // Considering that it is an old recommendation, we use the same reassembly // timeout that linux defines, which is 30 seconds: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 - reassembleTimeout = 30 * time.Second + ReassembleTimeout = 30 * time.Second // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -211,18 +219,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) -} - // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) var n int for { @@ -280,8 +285,14 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -292,6 +303,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -311,17 +323,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - if e.packetMustBeFragmented(pkt, gso) { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt return nil }); err != nil { - panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) } // Remove the packet that was just fragmented and process the rest. pkts.Remove(originalPkt) @@ -506,6 +524,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { r.Stats().IP.MalformedFragmentsReceived.Increment() return } + + // Set up a callback in case we need to send a Time Exceeded Message, as per + // RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards + // the datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + var ready bool var err error proto := h.Protocol() @@ -523,6 +563,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.More(), proto, pkt.Data, + releaseCB, ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -532,9 +573,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if !ready { return } - } + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) + } r.Stats().IP.PacketsDelivered.Increment() + p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport @@ -544,6 +590,27 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { e.handleICMP(r, pkt) return } + if len(h.Options()) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) + e.protocol.stack.Stats().MalformedRcvdPackets.Increment() + r.Stats().IP.MalformedPacketsReceived.Increment() + } + return + } + } switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { case stack.TransportPacketHandled: @@ -778,26 +845,32 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return mtu - header.IPv4MinimumSize -} -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader } - mtu -= uint32(pkt.NetworkHeader().View().Size()) - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - return mtu + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize + } + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 @@ -836,7 +909,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), } } @@ -862,3 +935,322 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head return fragPkt, more } + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. +func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. + // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if nextSlot >= optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. + if nextSlot+header.IPv4AddressSize > optlen { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = r.LocalAddress + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + r.Stats().IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := r.Stack().Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + r.Stats().IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + r.Stats().IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index fee11bb38..61672a5ff 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -21,11 +21,13 @@ import ( "math" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -39,7 +41,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -const extraHeaderReserve = 50 +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ @@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -103,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) { // checks the response. func TestIPv4Sanity(t *testing.T) { const ( - defaultMTU = header.IPv6MinimumMTU ttl = 255 nicID = 1 randomSequence = 123 @@ -118,27 +121,29 @@ func TestIPv4Sanity(t *testing.T) { ) tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - shouldFail bool - expectICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - options []byte + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options []byte + replyOptions []byte // if succeeds, reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 }{ { - name: "valid", - maxTotalLength: defaultMTU, + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, { name: "bad header checksum", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, badHeaderChecksum: true, @@ -157,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) { // received with TTL less than 2. { name: "zero TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, - shouldFail: false, }, { name: "one TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, - shouldFail: false, }, { name: "End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{0, 0, 0, 0}, + replyOptions: []byte{0, 0, 0, 0}, }, { name: "NOP options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 1, 1}, + replyOptions: []byte{1, 1, 1, 1}, }, { name: "NOP and End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 0, 0}, + replyOptions: []byte{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (0)", @@ -205,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip - 1)", @@ -213,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip + icmp - 1)", @@ -221,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad protocol", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, - expectICMP: true, + expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: []byte{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. + replyOptions: []byte{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: []byte{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: []byte{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. + { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. + 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: []byte{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: []byte{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") + e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -250,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -312,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) { reply, ok := e.Read() if !ok { if test.shouldFail { - if test.expectICMP { - t.Fatal("expected ICMP error response missing") + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) @@ -328,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } - // Make sure it's all in one buffer. - vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) - replyIPHeader := header.IPv4(vv.ToView()) + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - // At this stage we only know it's an IP header so verify that much. + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), ) - // All expected responses are ICMP packets. - if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { - t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + // Don't proceed any further if the checker found problems. + if t.Failed() { + t.FailNow() } - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - // Sanity check the response. + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { - case header.ICMPv4DstUnreachable: + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Checksum(), + checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) - if !test.shouldFail || !test.expectICMP { - t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) return case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. + sizeChange := len(test.replyOptions) - len(test.options) + checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength), - checker.IPv4Options(test.options), - checker.IPFullLength(uint16(requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( + checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), - checker.ICMPv4Checksum(), ), ) - if test.shouldFail { - t.Fatalf("unexpected Echo Reply packet\n") - } default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } @@ -462,7 +840,7 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", + description: "No fragmentation", mtu: 1280, gso: nil, transportHeaderLength: 0, @@ -483,6 +861,30 @@ var fragmentationTests = []struct { }, }, { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { description: "No fragmentation with big header", mtu: 2000, gso: nil, @@ -647,43 +1049,50 @@ func TestFragmentationWritePackets(t *testing.T) { } } -// TestFragmentationErrors checks that errors are returned from write packet +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { const ttl = 42 - expectedError := tcpip.ErrAborted - fragTests := []struct { + tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int - fragmentCount int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ { description: "No frag", mtu: 2000, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 1, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 3, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on second frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 1, - fragmentCount: 3, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag MTU smaller than header", @@ -691,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, - fragmentCount: 4, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) - r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != expectedError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } @@ -744,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) { autoChecksum bool // if true, the Checksum field will be overwritten. } - // These packets have both IHL and TotalLength set to 0. tests := []struct { name string fragments []fragmentData @@ -984,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -1027,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) { } } +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + Clock: clock, + }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) + } + + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { @@ -1577,7 +2249,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -1783,7 +2455,7 @@ func TestPacketQueing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a30437f02..0ac24a6fb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ead6bedcb..3c15e41a7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -170,8 +170,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -284,7 +287,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) } @@ -555,7 +558,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -605,7 +608,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -648,52 +651,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := stack.Route{ - LocalAddress: localAddr, - RemoteAddress: addr, - LocalLinkAddress: linkEP.LinkAddress(), - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - // If a remote address is not already known, then send a multicast - // solicitation since multicast addresses have a static mapping to link - // addresses. - if len(r.RemoteLinkAddress) == 0 { - r.RemoteAddress = header.SolicitedNodeAddr(addr) - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err } + defer r.Release() + r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(packet.NDPPayload()) - ns.SetTargetAddress(addr) + ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) + stat := p.stack.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -747,6 +744,13 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { @@ -839,7 +843,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -860,6 +866,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8dc33c560..aa8b5f2e5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -51,6 +51,7 @@ const ( var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -108,31 +109,27 @@ type stubNUDHandler struct { var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { s.probeCount++ } -func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { s.confirmationCount++ } -func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { } var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint - - linkAddr tcpip.LinkAddress -} + stack.LinkEndpoint -func (i *testInterface) LinkAddress() tcpip.LinkAddress { - return i.linkAddr + nicID tcpip.NICID } func (*testInterface) ID() tcpip.NICID { - return 0 + return nicID } func (*testInterface) IsLoopback() bool { @@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool { return true } +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -1235,26 +1240,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + snaddr := header.SolicitedNodeAddr(lladdr0) mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectedLinkAddr tcpip.LinkAddress - expectedAddr tcpip.Address + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectedLinkAddr: linkAddr1, - expectedAddr: lladdr0, + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, }, { - name: "Multicast", - remoteLinkAddr: "", - expectedLinkAddr: mcaddr, - expectedAddr: snaddr, + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, }, } @@ -1269,26 +1320,43 @@ func TestLinkAddressRequest(t *testing.T) { } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return } pkt, ok := linkEP.Read() if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } - if pkt.Route.RemoteAddress != test.expectedAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) } if pkt.Route.LocalAddress != lladdr1 { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) } checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedAddr), + checker.DstAddr(test.expectedRemoteAddr), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(lladdr0), @@ -1698,7 +1766,7 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9670696c7..1e38f3a9d 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -41,12 +41,12 @@ const ( // // Linux also uses 60 seconds for reassembly timeout: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 - reassembleTimeout = 60 * time.Second + ReassembleTimeout = 60 * time.Second // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff @@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and @@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. The transport -// header protocol number is required to avoid parsing the IPv6 extension -// headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - if fragMTU < pkt.TransportHeader().View().Size() { +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. return 0, 1, tcpip.ErrMessageTooLong } - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) - networkHeader := header.IPv6(pkt.NetworkHeader().View()) var n int for { @@ -468,8 +485,14 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -499,13 +522,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) - if e.packetMustBeFragmented(pb, gso) { + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -569,7 +599,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) @@ -747,6 +777,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. @@ -804,17 +836,59 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + r.Stats().IP.MalformedPacketsReceived.Increment() + r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) + return + } + // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the fragment if the size of the reassembled payload would exceed - // the maximum payload size. + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } + // Set up a callback in case we need to send a Time Exceeded Message as + // per RFC 2460 Section 4.5. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + r := r.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + } + r.Release() + } + } + // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( @@ -830,6 +904,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.More(), uint8(rawPayload.Identifier), rawPayload.Buf, + releaseCB, ) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -1427,14 +1502,31 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return maxPayloadSize + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize + } + return networkMTU, nil } // Options holds options to configure a new protocol. @@ -1488,7 +1580,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), ids: ids, hashIV: hashIV, @@ -1509,23 +1601,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are - // supported for outbound packets, their length should not affect the fragment - // MTU because they should only be transmitted once. - mtu -= uint32(pkt.NetworkHeader().View().Size()) - mtu -= header.IPv6FragmentHeaderSize - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - if mtu <= maxPayloadSize { - return mtu - } - return maxPayloadSize -} - func calculateFragmentReserve(pkt *stack.PacketBuffer) int { return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 297868f24..c593c0004 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" @@ -238,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -271,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -825,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1844,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1912,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - nicID = 1 - fragmentExtHdrLen = 8 + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" ) - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte } tests := []struct { @@ -1929,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 }{ { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { name: "fragments reassembled into a payload exceeding the max IPv6 payload size", fragments: []fragmentData{ { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, - []buffer.View{ - // Fragment extension header. - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, - ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, - 0, 0, 0, 1}), - // Payload length = 16 - payloadGen(16), - }, - ), + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ }, } @@ -1964,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) { NewProtocol, }, }) - e := channel.New(0, 1500, linkAddr1) + e := channel.New(1, 1500, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + var expectICMPPayload buffer.View for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() - vv.Append(f.data) + vv.AppendView(f.payload) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - })) + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { @@ -1999,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) { if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -2230,8 +2555,8 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", - mtu: 1280, + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 0, payloadSize: 1000, @@ -2241,7 +2566,18 @@ var fragmentationTests = []struct { }, { description: "Fragmented", - mtu: 1280, + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, gso: nil, transHdrLen: 0, payloadSize: 2000, @@ -2262,7 +2598,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with gso none", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, payloadSize: 1400, @@ -2273,7 +2609,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with big header", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 100, payloadSize: 1200, @@ -2448,8 +2784,8 @@ func TestFragmentationErrors(t *testing.T) { wantError: tcpip.ErrAborted, }, { - description: "Error on packet with MTU smaller than transport header", - mtu: 1280, + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, transHdrLen: 1500, payloadSize: 500, allowPackets: 0, @@ -2457,6 +2793,16 @@ func TestFragmentationErrors(t *testing.T) { mockError: nil, wantError: tcpip.ErrMessageTooLong, }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } for _, ft := range tests { diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..261705575 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -361,6 +361,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..380688038 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..eebf43a1f 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -68,7 +68,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +84,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,6 +111,10 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). @@ -118,7 +122,6 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, - LocalAddr: localAddr, LinkAddr: linkAddr, State: Static, UpdatedAt: time.Now(), @@ -126,13 +129,13 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +155,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -207,7 +210,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,8 +223,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. @@ -292,8 +294,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..d81f00848 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -128,9 +128,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +194,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. @@ -294,9 +293,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +303,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +326,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +356,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +367,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +397,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +412,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +460,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +478,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +522,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +577,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +616,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +658,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +702,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +725,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +764,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +786,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +828,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +867,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +896,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +930,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +947,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +957,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +1002,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1029,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1070,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1082,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1101,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1131,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1140,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1171,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1196,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1236,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1274,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1330,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1372,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1381,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1401,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1439,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1451,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1493,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1518,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1549,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1542,37 +1602,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1659,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1716,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1740,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1775,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..bd80f95bd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -27,7 +27,6 @@ import ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { Addr tcpip.Address - LocalAddr tcpip.Address LinkAddr tcpip.LinkAddress State NeighborState UpdatedAt time.Time @@ -106,35 +105,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +164,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +182,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -206,63 +205,19 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,19 +232,13 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } @@ -315,15 +264,72 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAt = time.Now() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() case Incomplete, Reachable, Delay, Probe, Static, Failed: // Do nothing @@ -345,21 +351,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +399,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +416,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +426,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. - e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +481,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..e8e0e571b 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -52,19 +52,16 @@ const ( // predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -125,14 +122,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +144,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +187,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -245,7 +230,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -326,7 +311,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -350,9 +335,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -388,9 +375,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,7 +395,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -468,16 +457,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -498,7 +491,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -530,16 +523,20 @@ func TestEntryIncompleteToReachable(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -563,7 +560,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -605,16 +602,20 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -629,7 +630,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -663,16 +664,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -687,7 +692,7 @@ func TestEntryIncompleteToStale(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -719,16 +724,20 @@ func TestEntryIncompleteToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +753,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +792,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -822,7 +835,7 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -866,16 +879,20 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -896,7 +913,7 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -932,16 +949,20 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,7 +982,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -992,23 +1013,29 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1029,7 +1056,7 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1062,23 +1089,29 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1099,7 +1132,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1136,23 +1169,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1173,7 +1212,7 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -1210,23 +1249,29 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1247,7 +1292,7 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1283,16 +1328,20 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1307,7 +1356,7 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1347,23 +1396,29 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1378,7 +1433,7 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1418,23 +1473,29 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1449,7 +1510,7 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1489,23 +1550,29 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1520,7 +1587,7 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1556,23 +1623,29 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1587,7 +1660,7 @@ func TestEntryStaleToDelay(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1596,7 +1669,7 @@ func TestEntryStaleToDelay(t *testing.T) { if got, want := e.neigh.State, Stale; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1620,23 +1693,29 @@ func TestEntryStaleToDelay(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,13 +1735,13 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1692,37 +1771,47 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,13 +1832,13 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1786,37 +1875,47 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1936,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1880,37 +1979,47 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1925,13 +2034,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -1966,23 +2075,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1997,13 +2112,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2031,30 +2146,38 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2069,13 +2192,13 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, _ := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2107,30 +2230,38 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,13 +2276,13 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -2170,7 +2301,6 @@ func TestEntryDelayToProbe(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2184,30 +2314,38 @@ func TestEntryDelayToProbe(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,13 +2366,13 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2250,7 +2388,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2274,37 +2411,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2325,13 +2472,13 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2347,7 +2494,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2375,37 +2521,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2426,13 +2582,13 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2448,7 +2604,6 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2479,30 +2634,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,7 +2692,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2539,7 +2702,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2572,37 +2734,47 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,13 +2794,13 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2644,7 +2816,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2677,44 +2848,56 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,13 +2917,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2756,7 +2939,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2786,44 +2968,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,13 +3037,13 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) @@ -2865,7 +3059,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2895,44 +3088,56 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3151,116 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + { + wantProbes := []entryTestProbeInfo{ + // Caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. + for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3034,12 +3268,6 @@ func TestEntryProbeToFailed(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,13 +3282,13 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime @@ -3077,17 +3305,14 @@ func TestEntryFailedGetsDeleted(t *testing.T) { { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -3101,37 +3326,47 @@ func TestEntryFailedGetsDeleted(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..17f2e6b46 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -678,7 +687,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease + // the TTL field for ipv4/ipv6. // pkt may have set its header and may not have enough headroom for // link-layer header for the other link to prepend. Here we create a new diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..4af04846f 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -169,7 +169,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..203f3b51f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -490,6 +490,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -764,13 +767,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..e8f1c110e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -830,6 +830,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1323,7 +1337,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..4eed4ced4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3672,3 +3672,87 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..6b8071467 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -28,7 +28,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d77848d61..3ab2b7654 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -763,6 +762,10 @@ const ( // endpoint that all packets being written have an IP header and the // endpoint should not attach an IP header. IPHdrIncludedOption + + // AcceptConnOption is used by GetSockOptBool to indicate if the + // socket is a listening socket. + AcceptConnOption ) // SockOptInt represents socket options which values have the int type. @@ -1256,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -1496,6 +1505,15 @@ type IPStats struct { // IPTablesOutputDropped is the total number of IP packets dropped in // the Output chain. IPTablesOutputDropped *StatCounter + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived *StatCounter + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived *StatCounter + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived *StatCounter } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 41eb0ca44..a17234946 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil default: diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 072601d2d..31831a6d8 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrNotSupported + switch opt { + case tcpip.AcceptConnOption: + return false, nil + default: + return false, tcpip.ErrNotSupported + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e37c00523..79f688129 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.IPHdrIncludedOption: diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 33bfb56cd..7d97cbdc7 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { } // beforeSave is invoked by stateify. -func (ep *endpoint) beforeSave() { +func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming // packets. - ep.rcvMu.Lock() + e.rcvMu.Lock() } // saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 + e.rcvBufSizeMax = 0 // Release the lock acquired in beforeSave() so regular endpoint closing // logic can proceed after save. - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return max } // loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max } // afterLoad is invoked by stateify. -func (ep *endpoint) afterLoad() { - stack.StackFromEnv.RegisterRestoredEndpoint(ep) +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s // If the endpoint is connected, re-connect. - if ep.connected { + if e.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) if err != nil { panic(err) } } // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + if e.bound { + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if ep.associated { - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + if e.associated { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b706438bd..6b3238d6b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -425,20 +425,17 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer ctx.synRcvdCount.dec() - defer func() { - e.mu.Lock() - e.decSynRcvdCount() - e.mu.Unlock() - }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() + e.decSynRcvdCount() return } ctx.removePendingEndpoint(n) + e.decSynRcvdCount() n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() @@ -456,7 +453,9 @@ func (e *endpoint) incSynRcvdCount() bool { } func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() e.synRcvdCount-- + e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3bcd3923a..c826942e9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.MulticastLoopOption: return true, nil + case tcpip.AcceptConnOption: + e.LockUser() + defer e.UnlockUser() + + return e.EndpointState() == StateListen, nil + default: return false, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 7ef2df377..833a7b470 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { return found } -// Dump prints the state of the scoreboard structure. +// String returns human-readable state of the scoreboard structure. func (s *SACKScoreboard) String() string { var str strings.Builder str.WriteString("SACKScoreboard: {") diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a7149efd0..5f05608e2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) { } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendPacket(nil, &context.Headers{ @@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki } func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendV6Packet(nil, &context.Headers{ @@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) { // Test acceptance. // Start listening. - listenBacklog := 2 + listenBacklog := 10 if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } time.Sleep(50 * time.Millisecond) @@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, + SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(789), @@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { @@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } + // Send a SYN request w/ sequence number higher than // the highest sequence number sent. iss = seqnum.Value(792) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4d7847142..79646fefe 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -373,6 +373,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, const icmpv4VariableHeaderOffset = 4 copy(icmp[icmpv4VariableHeaderOffset:], p1) copy(icmp[header.ICMPv4PayloadOffset:], p2) + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) // Inject packet. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d57ed5d79..cdb5127ab 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -895,6 +895,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil + case tcpip.AcceptConnOption: + return false, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -1366,6 +1369,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.rcvMu.Unlock() } + e.lastErrorMu.Lock() + hasError := e.lastError != nil + e.lastErrorMu.Unlock() + if hasError { + result |= waiter.EventErr + } return result } @@ -1465,14 +1474,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() - defer e.mu.RUnlock() - if e.state == StateConnected { e.lastErrorMu.Lock() - defer e.lastErrorMu.Unlock() - e.lastError = tcpip.ErrConnectionRefused + e.lastErrorMu.Unlock() + e.mu.RUnlock() + + e.waiterQueue.Notify(waiter.EventErr) + return } + e.mu.RUnlock() } } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b4604ba35..fb7738dda 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) { // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } // In the case of large payloads the IP packet may be truncated. Update diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 67a950444..08519d986 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { // // +stateify savable type Queue struct { - list waiterList `state:"zerovalue"` + list waiterList mu sync.RWMutex `state:"nosave"` } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 248f77c34..b97dc3c47 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -38,6 +38,7 @@ go_library( "//pkg/memutil", "//pkg/rand", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/control", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 894651519..4e0f0d57a 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/state" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" @@ -367,12 +368,20 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { cm.l.k = k // Set up the restore environment. + ctx := k.SupervisorContext() mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints) - renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) - if err != nil { - return fmt.Errorf("creating RestoreEnvironment: %v", err) + if kernel.VFS2Enabled { + ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) + if err != nil { + return fmt.Errorf("configuring filesystem restore: %v", err) + } + } else { + renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) + if err != nil { + return fmt.Errorf("creating RestoreEnvironment: %v", err) + } + fs.SetRestoreEnvironment(*renv) } - fs.SetRestoreEnvironment(*renv) // Prepare to load from the state file. if eps, ok := networkStack.(*netstack.Stack); ok { @@ -399,7 +408,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Load the state. loadOpts := state.LoadOpts{Source: specFile} - if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil { + if err := loadOpts.Load(ctx, k, networkStack, time.NewCalibratedClocks(), &vfs.CompleteRestoreOptions{}); err != nil { return err } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index ddf288456..6b6ae98d7 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -105,33 +105,28 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name // mandatory mounts that are required by the OCI specification. func compileMounts(spec *specs.Spec) []specs.Mount { // Keep track of whether proc and sys were mounted. - var procMounted, sysMounted bool + var procMounted, sysMounted, devMounted, devptsMounted bool var mounts []specs.Mount - // Always mount /dev. - mounts = append(mounts, specs.Mount{ - Type: devtmpfs.Name, - Destination: "/dev", - }) - - mounts = append(mounts, specs.Mount{ - Type: devpts.Name, - Destination: "/dev/pts", - }) - // Mount all submounts from the spec. for _, m := range spec.Mounts { if !specutils.IsSupportedDevMount(m) { log.Warningf("ignoring dev mount at %q", m.Destination) continue } - mounts = append(mounts, m) switch filepath.Clean(m.Destination) { case "/proc": procMounted = true case "/sys": sysMounted = true + case "/dev": + m.Type = devtmpfs.Name + devMounted = true + case "/dev/pts": + m.Type = devpts.Name + devptsMounted = true } + mounts = append(mounts, m) } // Mount proc and sys even if the user did not ask for it, as the spec @@ -149,6 +144,18 @@ func compileMounts(spec *specs.Spec) []specs.Mount { Destination: "/sys", }) } + if !devMounted { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: devtmpfs.Name, + Destination: "/dev", + }) + } + if !devptsMounted { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: devpts.Name, + Destination: "/dev/pts", + }) + } // The mandatory mounts should be ordered right after the root, in case // there are submounts of these mandatory mounts already in the spec. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 8ad000497..8c6ab213d 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/memutil" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fdimport" @@ -476,6 +477,12 @@ func (l *Loader) Destroy() { // save/restore. l.k.Release() + // All sentry-created resources should have been released at this point; + // check for reference leaks. + if refsvfs2.LeakCheckEnabled() { + refsvfs2.DoLeakCheck() + } + // In the success case, stdioFDs and goferFDs will only contain // released/closed FDs that ownership has been passed over to host FDs and // gofer sessions. Close them here in case of failure. @@ -737,7 +744,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn return nil, err } - // Add the HOME enviroment variable if it is not already set. + // Add the HOME environment variable if it is not already set. var envv []string if kernel.VFS2Enabled { envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2, diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 004da5b40..b157387ef 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -210,6 +210,9 @@ func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *c ReadOnly: c.root.Readonly, GetFilesystemOptions: vfs.GetFilesystemOptions{ Data: strings.Join(data, ","), + InternalData: gofer.InternalFilesystemOptions{ + UniqueID: "/", + }, }, InternalMount: true, } @@ -427,6 +430,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo fsName := m.Type useOverlay := false var data []string + var iopts interface{} // Find filesystem name and FS specific data field. switch m.Type { @@ -451,6 +455,9 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) + iopts = gofer.InternalFilesystemOptions{ + UniqueID: m.Destination, + } // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly @@ -462,7 +469,8 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo opts := &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ - Data: strings.Join(data, ","), + Data: strings.Join(data, ","), + InternalData: iopts, }, InternalMount: true, } @@ -667,3 +675,21 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede } return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds) } + +// configureRestore returns an updated context.Context including filesystem +// state used by restore defined by conf. +func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) { + fdmap := make(map[string]int) + fdmap["/"] = c.fds.remove() + mounts, err := c.prepareMountsVFS2() + if err != nil { + return ctx, err + } + for i := range c.mounts { + submount := &mounts[i] + if submount.fd >= 0 { + fdmap[submount.Destination] = submount.fd + } + } + return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil +} diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 56da21584..5bd0afc52 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -198,8 +199,13 @@ func LoadPaths(pid string) (map[string]string, error) { } defer f.Close() + return loadPathsHelper(f) +} + +func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { paths := make(map[string]string) - scanner := bufio.NewScanner(f) + + scanner := bufio.NewScanner(cgroup) for scanner.Scan() { // Format: ID:[name=]controller1,controller2:path // Example: 2:cpu,cpuacct:/user.slice @@ -207,6 +213,9 @@ func LoadPaths(pid string) (map[string]string, error) { if len(tokens) != 3 { return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text()) } + if len(tokens[1]) == 0 { + continue + } for _, ctrlr := range strings.Split(tokens[1], ",") { // Remove prefix for cgroups with no controller, eg. systemd. ctrlr = strings.TrimPrefix(ctrlr, "name=") diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 4db5ee5c3..9794517a7 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -647,3 +647,83 @@ func TestPids(t *testing.T) { }) } } + +func TestLoadPaths(t *testing.T) { + for _, tc := range []struct { + name string + cgroups string + want map[string]string + err string + }{ + { + name: "abs-path", + cgroups: "0:ctr:/path", + want: map[string]string{"ctr": "/path"}, + }, + { + name: "rel-path", + cgroups: "0:ctr:rel-path", + want: map[string]string{"ctr": "rel-path"}, + }, + { + name: "non-controller", + cgroups: "0:name=systemd:/path", + want: map[string]string{"systemd": "/path"}, + }, + { + name: "empty", + }, + { + name: "multiple", + cgroups: "0:ctr0:/path0\n" + + "1:ctr1:/path1\n" + + "2::/empty\n", + want: map[string]string{ + "ctr0": "/path0", + "ctr1": "/path1", + }, + }, + { + name: "missing-field", + cgroups: "0:nopath\n", + err: "invalid cgroups file", + }, + { + name: "too-many-fields", + cgroups: "0:ctr:/path:extra\n", + err: "invalid cgroups file", + }, + { + name: "multiple-malformed", + cgroups: "0:ctr0:/path0\n" + + "1:ctr1:/path1\n" + + "2:\n", + err: "invalid cgroups file", + }, + } { + t.Run(tc.name, func(t *testing.T) { + r := strings.NewReader(tc.cgroups) + got, err := loadPathsHelper(r) + if len(tc.err) == 0 { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } else if !strings.Contains(err.Error(), tc.err) { + t.Fatalf("Wrong error message, want: *%s*, got: %v", tc.err, err) + } + for key, vWant := range tc.want { + vGot, ok := got[key] + if !ok { + t.Errorf("Missing controller %q", key) + } + if vWant != vGot { + t.Errorf("Wrong controller %q value, want: %q, got: %q", key, vWant, vGot) + } + delete(got, key) + } + for k, v := range got { + t.Errorf("Unexpected controller %q: %q", k, v) + } + }) + } +} diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 88991b521..139edbd49 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/specutils" ) // Start implements subcommands.Command for the "start" command. @@ -58,6 +59,12 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s if err != nil { Fatalf("loading container: %v", err) } + // Read the spec again here to ensure flag annotations from the spec are + // applied to "conf". + if _, err := specutils.ReadSpec(c.BundleDir, conf); err != nil { + Fatalf("reading spec: %v", err) + } + if err := c.Start(conf); err != nil { Fatalf("starting container: %v", err) } diff --git a/runsc/container/container.go b/runsc/container/container.go index 63f64ce6e..52e1755ce 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -312,6 +312,14 @@ func New(conf *config.Config, args Args) (*Container, error) { if isRoot(args.Spec) { log.Debugf("Creating new sandbox for container %q", args.ID) + if args.Spec.Linux == nil { + args.Spec.Linux = &specs.Linux{} + } + // Don't force the use of cgroups in tests because they lack permission to do so. + if args.Spec.Linux.CgroupsPath == "" && !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { + args.Spec.Linux.CgroupsPath = "/" + args.ID + } + // Create and join cgroup before processes are created to ensure they are // part of the cgroup from the start (and all their children processes). cg, err := cgroup.New(args.Spec) @@ -321,7 +329,13 @@ func New(conf *config.Config, args Args) (*Container, error) { if cg != nil { // If there is cgroup config, install it before creating sandbox process. if err := cg.Install(args.Spec.Linux.Resources); err != nil { - return nil, fmt.Errorf("configuring cgroup: %v", err) + switch { + case errors.Is(err, syscall.EACCES) && conf.Rootless: + log.Warningf("Skipping cgroup configuration in rootless mode: %v", err) + cg = nil + default: + return nil, fmt.Errorf("configuring cgroup: %v", err) + } } } if err := runInCgroup(cg, func() error { diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 0392e3e83..45abc1425 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -344,15 +344,9 @@ func IsSupportedDevMount(m specs.Mount) bool { var existingDevices = []string{ "/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr", "/dev/null", "/dev/zero", "/dev/full", "/dev/random", - "/dev/urandom", "/dev/shm", "/dev/pts", "/dev/ptmx", + "/dev/urandom", "/dev/shm", "/dev/ptmx", } dst := filepath.Clean(m.Destination) - if dst == "/dev" { - // OCI spec uses many different mounts for the things inside of '/dev'. We - // have a single mount at '/dev' that is always mounted, regardless of - // whether it was asked for, as the spec says we SHOULD. - return false - } for _, dev := range existingDevices { if dst == dev || strings.HasPrefix(dst, dev+"/") { return false @@ -425,7 +419,7 @@ func Mount(src, dst, typ string, flags uint32) error { // Special case, as there is no source directory for proc mounts. isDir = true } else if fi, err := os.Stat(src); err != nil { - return fmt.Errorf("Stat(%q) failed: %v", src, err) + return fmt.Errorf("stat(%q) failed: %v", src, err) } else { isDir = fi.IsDir() } @@ -433,25 +427,25 @@ func Mount(src, dst, typ string, flags uint32) error { if isDir { // Create the destination directory. if err := os.MkdirAll(dst, 0777); err != nil { - return fmt.Errorf("Mkdir(%q) failed: %v", dst, err) + return fmt.Errorf("mkdir(%q) failed: %v", dst, err) } } else { // Create the parent destination directory. parent := path.Dir(dst) if err := os.MkdirAll(parent, 0777); err != nil { - return fmt.Errorf("Mkdir(%q) failed: %v", parent, err) + return fmt.Errorf("mkdir(%q) failed: %v", parent, err) } // Create the destination file if it does not exist. f, err := os.OpenFile(dst, syscall.O_CREAT, 0777) if err != nil { - return fmt.Errorf("Open(%q) failed: %v", dst, err) + return fmt.Errorf("open(%q) failed: %v", dst, err) } f.Close() } // Do the mount. if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil { - return fmt.Errorf("Mount(%q, %q, %d) failed: %v", src, dst, flags, err) + return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err) } return nil } diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index 32bf2a992..d3e5efd4f 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -441,9 +441,20 @@ func (FilterOutputDestination) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { - rules := [][]string{ - {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, - {"-P", "OUTPUT", "DROP"}, + var rules [][]string + if ipv6 { + rules = [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + // Allow solicited node multicast addresses so we can send neighbor + // solicitations. + {"-A", "OUTPUT", "-d", "ff02::1:ff00:0/104", "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + } else { + rules = [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } } if err := filterTableRules(ipv6, rules); err != nil { return err diff --git a/test/iptables/nat.go b/test/iptables/nat.go index dd9a18339..b98d99fb8 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -577,11 +577,18 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. connCh := make(chan int) errCh := make(chan error) go func() { - connFD, _, err := syscall.Accept(sockfd) - if err != nil { - errCh <- err + for { + connFD, _, err := syscall.Accept(sockfd) + if errors.Is(err, syscall.EINTR) { + continue + } + if err != nil { + errCh <- err + return + } + connCh <- connFD + return } - connCh <- connFD }() // Wait for accept() to return or for the context to finish. diff --git a/test/runner/runner.go b/test/runner/runner.go index 22d535f8d..7ab2c3edf 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -53,6 +53,9 @@ var ( runscPath = flag.String("runsc", "", "path to runsc binary") addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") + // TODO(gvisor.dev/issue/4572): properly support leak checking for runsc, and + // set to true as the default for the test runner. + leakCheck = flag.Bool("leak-check", false, "check for reference leaks") ) // runTestCaseNative runs the test case directly on the host machine. @@ -174,6 +177,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { if *addUDSTree { args = append(args, "-fsgofer-host-uds") } + if *leakCheck { + args = append(args, "-ref-leak-mode=log-names") + } testLogDir := "" if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv index d978baca7..e41441374 100644 --- a/test/runtimes/exclude/java11.csv +++ b/test/runtimes/exclude/java11.csv @@ -144,6 +144,7 @@ jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing f jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default' +jdk/jfr/event/oldobject/TestLargeRootSet.java,,Flaky - `main' threw exception: java.lang.RuntimeException: Could not find root object jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index ba993814f..c4e7917ec 100644 --- a/test/runtimes/exclude/nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -1,31 +1,22 @@ test name,bug id,comment async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29 -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-napi.js,, +benchmark/test-benchmark-fs.js,,Broken test +benchmark/test-benchmark-napi.js,,Broken test doctool/test-make-doc.js,b/68848110,Expected to fail. internet/test-dgram-multicast-set-interface-lo.js,b/162798882, -internet/test-doctool-versions.js,, -internet/test-uv-threadpool-schedule.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, +internet/test-doctool-versions.js,,Broken test +internet/test-uv-threadpool-schedule.js,,Broken test parallel/test-dgram-bind-fd.js,b/132447356, parallel/test-dgram-socket-buffer-size.js,b/68847921, parallel/test-dns-channel-timeout.js,b/161893056, -parallel/test-fs-access.js,, -parallel/test-fs-watchfile.js,,Flaky - File already exists error -parallel/test-fs-write-stream.js,b/166819807,Flaky -parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky -parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky -parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 +parallel/test-fs-access.js,,Broken test +parallel/test-fs-watchfile.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky - VFS1 only +parallel/test-http-writable-true-after-close.js,b/171301436,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 parallel/test-os.js,b/63997097, -parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE -parallel/test-process-uid-gid.js,, -parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE -parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE +parallel/test-process-uid-gid.js,,Does not work inside Docker with gid nobody pseudo-tty/test-assert-colors.js,b/162801321, pseudo-tty/test-assert-no-color.js,b/162801321, pseudo-tty/test-assert-position-indicator.js,b/162801321, @@ -48,11 +39,7 @@ pseudo-tty/test-tty-stdout-resize.js,b/162801321, pseudo-tty/test-tty-stream-constructors.js,b/162801321, pseudo-tty/test-tty-window-size.js,b/162801321, pseudo-tty/test-tty-wrap.js,b/162801321, -pummel/test-heapdump-http2.js,,Flaky -pummel/test-net-pingpong.js,, +pummel/test-net-pingpong.js,,Broken test pummel/test-vm-memleak.js,b/162799436, -pummel/test-watch-file.js,,Flaky - Timeout -sequential/test-child-process-pass-fd.js,b/63926391,Flaky -sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE -sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout -tick-processor/test-tick-processor-builtin.js,, +pummel/test-watch-file.js,,Flaky - VFS1 only +tick-processor/test-tick-processor-builtin.js,,Broken test diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index a73f3bcfb..9e1f4c050 100644 --- a/test/runtimes/exclude/php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv @@ -26,12 +26,13 @@ ext/standard/tests/file/php_fd_wrapper_01.phpt,, ext/standard/tests/file/php_fd_wrapper_02.phpt,, ext/standard/tests/file/php_fd_wrapper_03.phpt,, ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,b/162894969, +ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,VFS1 only failure ext/standard/tests/file/rename_variation.phpt,b/68717309, ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, +ext/standard/tests/network/bug20134.phpt,b/171347929,Flaky ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky ext/standard/tests/streams/stream_socket_sendto.phpt,, diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv index 8760f8951..911f22855 100644 --- a/test/runtimes/exclude/python3.7.3.csv +++ b/test/runtimes/exclude/python3.7.3.csv @@ -18,4 +18,3 @@ test_selectors,b/76116849,OSError not raised with epoll test_smtplib,b/162980434,unclosed sockets test_signal,,Flaky - signal: alarm clock test_socket,b/75983380, -test_subprocess,b/162980831, diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go index e5607ac92..81cb68381 100644 --- a/test/runtimes/proctor/main.go +++ b/test/runtimes/proctor/main.go @@ -22,6 +22,7 @@ import ( "log" "os" "strings" + "syscall" "gvisor.dev/gvisor/test/runtimes/proctor/lib" ) @@ -33,6 +34,29 @@ var ( pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) +// setNumFilesLimit changes the NOFILE soft rlimit if it is too high. +func setNumFilesLimit() error { + // In docker containers, the default value of the NOFILE limit is + // 1048576. A few runtime tests (e.g. python:test_subprocess) + // enumerates all possible file descriptors and these tests can fail by + // timeout if the NOFILE limit is too high. On gVisor, syscalls are + // slower so these tests will need even more time to pass. + const nofile = 32768 + rLimit := syscall.Rlimit{} + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + if rLimit.Cur > nofile { + rLimit.Cur = nofile + err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + } + return nil +} + func main() { flag.Parse() @@ -74,6 +98,10 @@ func main() { tests = strings.Split(*testNames, ",") } + if err := setNumFilesLimit(); err != nil { + log.Fatalf("%v", err) + } + // Run tests. cmds := tr.TestCmds(tests) for _, cmd := range cmds { diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 572f39a5d..c94c1d5bd 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1285,6 +1285,7 @@ cc_binary( "//test/util:mount_util", "//test/util:multiprocess_util", "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -2434,6 +2435,7 @@ cc_library( "@com_google_absl//absl/memory", gtest, "//test/util:posix_error", + "//test/util:save_util", "//test/util:test_util", ], alwayslink = 1, @@ -3441,6 +3443,7 @@ cc_binary( "@com_google_absl//absl/strings", gtest, "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index b96907b30..1635c6d0c 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -125,6 +125,16 @@ TEST(MknodTest, Socket) { ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds()); } +PosixErrorOr<FileDescriptor> OpenRetryEINTR(std::string const& path, int flags, + mode_t mode = 0) { + while (true) { + auto maybe_fd = Open(path, flags, mode); + if (maybe_fd.ok() || maybe_fd.error().errno_value() != EINTR) { + return maybe_fd; + } + } +} + TEST(MknodTest, Fifo) { const std::string fifo = NewTempAbsPath(); ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), @@ -139,14 +149,16 @@ TEST(MknodTest, Fifo) { // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); // Write-end of the pipe. - FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY)); + FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_WRONLY)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); } @@ -164,15 +176,16 @@ TEST(MknodTest, FifoOtrunc) { std::vector<char> buf(512); // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); // Write-end of the pipe. - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); + FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE( + OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); } @@ -192,14 +205,15 @@ TEST(MknodTest, FifoTruncNoOp) { std::vector<char> buf(512); // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); + FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE( + OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC)); EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 3aab25b23..d65b7d031 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -34,6 +34,7 @@ #include "test/util/mount_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -131,7 +132,9 @@ TEST(MountTest, UmountDetach) { ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700", /* umountflags= */ MNT_DETACH)); const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_NE(before.st_ino, after.st_ino); + EXPECT_FALSE(before.st_dev == after.st_dev && before.st_ino == after.st_ino) + << "mount point has device number " << before.st_dev + << " and inode number " << before.st_ino << " before and after mount"; // Create files in the new mount. constexpr char kContents[] = "no no no"; @@ -147,12 +150,14 @@ TEST(MountTest, UmountDetach) { // Unmount the tmpfs. mount.Release()(); - // Only check for inode number equality if the directory is not in overlayfs. - // If xino option is not enabled and if all overlayfs layers do not belong to - // the same filesystem then "the value of st_ino for directory objects may not - // be persistent and could change even while the overlay filesystem is - // mounted." -- Documentation/filesystems/overlayfs.txt - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + // Inode numbers for gofer-accessed files may change across save/restore. + // + // For overlayfs, if xino option is not enabled and if all overlayfs layers do + // not belong to the same filesystem then "the value of st_ino for directory + // objects may not be persistent and could change even while the overlay + // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(before.st_ino, after2.st_ino); } @@ -214,18 +219,23 @@ TEST(MountTest, MountTmpfs) { const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(s.st_mode, S_IFDIR | 0700); - EXPECT_NE(s.st_ino, before.st_ino); + EXPECT_FALSE(before.st_dev == s.st_dev && before.st_ino == s.st_ino) + << "mount point has device number " << before.st_dev + << " and inode number " << before.st_ino << " before and after mount"; EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777)); } // Now that dir is unmounted again, we should have the old inode back. - // Only check for inode number equality if the directory is not in overlayfs. - // If xino option is not enabled and if all overlayfs layers do not belong to - // the same filesystem then "the value of st_ino for directory objects may not - // be persistent and could change even while the overlay filesystem is - // mounted." -- Documentation/filesystems/overlayfs.txt - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + // + // Inode numbers for gofer-accessed files may change across save/restore. + // + // For overlayfs, if xino option is not enabled and if all overlayfs layers do + // not belong to the same filesystem then "the value of st_ino for directory + // objects may not be persistent and could change even while the overlay + // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(before.st_ino, after.st_ino); } diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index b558e3a01..a7c46adbf 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -664,6 +664,17 @@ TEST_P(RawPacketTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +TEST_P(RawPacketTest, GetSocketAcceptConn) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e8fcc4439..7a0f33dff 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -26,6 +26,7 @@ #include <string.h> #include <sys/mman.h> #include <sys/prctl.h> +#include <sys/ptrace.h> #include <sys/stat.h> #include <sys/statfs.h> #include <sys/utsname.h> @@ -512,6 +513,414 @@ TEST(ProcSelfAuxv, EntryValues) { EXPECT_EQ(i, proc_auxv.size()); } +// Just open and read a part of /proc/self/mem, check that we can read an item. +TEST(ProcPidMem, Read) { + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + char input[] = "hello-world"; + char output[sizeof(input)]; + ASSERT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(input)), + SyscallSucceedsWithValue(sizeof(input))); + ASSERT_STREQ(input, output); +} + +// Perform read on an unmapped region. +TEST(ProcPidMem, Unmapped) { + // Strategy: map then unmap, so we have a guaranteed unmapped region + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + // Fill it with things + memset(mapping.ptr(), 'x', mapping.len()); + char expected = 'x', output; + ASSERT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(mapping.ptr())), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(expected, output); + + // Unmap region again + ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds()); + + // Now we want EIO error + ASSERT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(mapping.ptr())), + SyscallFailsWithErrno(EIO)); +} + +// Perform read repeatedly to verify offset change. +TEST(ProcPidMem, RepeatedRead) { + auto const num_reads = 3; + char expected[] = "01234567890abcdefghijkl"; + char output[sizeof(expected) / num_reads]; + + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + for (auto i = 0; i < num_reads; i++) { + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[i * sizeof(output)], output, sizeof(output)), + 0); + } +} + +// Perform seek operations repeatedly. +TEST(ProcPidMem, RepeatedSeek) { + auto const num_reads = 3; + char expected[] = "01234567890abcdefghijkl"; + char output[sizeof(expected) / num_reads]; + + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + // Read from start + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0); + // Skip ahead one read + ASSERT_THAT(lseek(memfd.get(), sizeof(output), SEEK_CUR), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected) + + sizeof(output) * 2)); + // Do read again + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[2 * sizeof(output)], output, sizeof(output)), 0); + // Skip back three reads + ASSERT_THAT(lseek(memfd.get(), -3 * sizeof(output), SEEK_CUR), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + // Do read again + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0); + // Check that SEEK_END does not work + ASSERT_THAT(lseek(memfd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL)); +} + +// Perform read past an allocated memory region. +TEST(ProcPidMem, PartialRead) { + // Strategy: map large region, then do unmap and remap smaller region + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + + Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds()); + Mapping smaller_mapping = ASSERT_NO_ERRNO_AND_VALUE( + Mmap(mapping.ptr(), kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + + // Fill it with things + memset(smaller_mapping.ptr(), 'x', smaller_mapping.len()); + + // Now we want no error + char expected[] = {'x'}; + std::unique_ptr<char[]> output(new char[kPageSize]); + off_t read_offset = + reinterpret_cast<off_t>(smaller_mapping.ptr()) + kPageSize - 1; + ASSERT_THAT( + pread(memfd.get(), output.get(), sizeof(output.get()), read_offset), + SyscallSucceedsWithValue(sizeof(expected))); + // Since output is larger, than expected we have to do manual compare + ASSERT_EQ(expected[0], (output).get()[0]); +} + +// Perform read on /proc/[pid]/mem after exit. +TEST(ProcPidMem, AfterExit) { + int pfd1[2] = {}; + int pfd2[2] = {}; + + char expected[] = "hello-world"; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char ok = 1; + TEST_CHECK(WriteFd(pfd1[1], &ok, sizeof(ok)) == sizeof(ok)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Wait for child to be alive and well + char ok = 0; + EXPECT_THAT(ReadFd(pfd1[0], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + + // Expect that we can read + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(&expected)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + + // Expect that we can't read anymore + EXPECT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(&expected)), + SyscallSucceedsWithValue(0)); +} + +// Read from /proc/[pid]/mem with different UID/GID and attached state. +TEST(ProcPidMem, DifferentUserAttached) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_DAC_OVERRIDE))); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_PTRACE))); + + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + ScopedThread([&] { + // Attach to child subprocess without stopping it + EXPECT_THAT(ptrace(PTRACE_SEIZE, child_pid, NULL, NULL), SyscallSucceeds()); + + // Keep capabilities after setuid + EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + + // Only restore CAP_SYS_PTRACE and CAP_DAC_OVERRIDE + EXPECT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, true)); + EXPECT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, true)); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + char expected[] = "hello-world"; + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(target_location)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); +} + +// Attempt to read from /proc/[pid]/mem with different UID/GID. +TEST(ProcPidMem, DifferentUser) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); + + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + ScopedThread([&] { + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + + // Attempt to open /proc/[child_pid]/mem + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + EXPECT_THAT(open(mempath.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + }); +} + +// Perform read on /proc/[pid]/mem with same UID/GID. +TEST(ProcPidMem, SameUser) { + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + char expected[] = "hello-world"; + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(target_location)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); +} + // Just open and read /proc/self/maps, check that we can find [stack] TEST(ProcSelfMaps, Basic) { auto proc_self_maps = diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc index 9fb1b3a2c..738923822 100644 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ b/test/syscalls/linux/proc_pid_smaps.cc @@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( // amount of whitespace). if (!entry) { std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message() << std::endl; + << maybe_maps_entry.error().message() << std::endl; return PosixError( EINVAL, absl::StrCat("smaps field line without preceding maps line: ", l)); diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 1b9dbc584..bd779da92 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -438,6 +438,19 @@ TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +// Test getsockopt for SO_ACCEPTCONN. +TEST_F(RawSocketICMPTest, GetSocketAcceptConn) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // We're going to receive both the echo request and reply, but the order is // indeterminate. diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index e9b131ca9..ed6a1c2aa 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -486,6 +486,62 @@ TEST(SemaphoreTest, SemIpcSet) { ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES)); } +TEST(SemaphoreTest, SemCtlIpcStat) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + const uid_t kUid = getuid(); + const gid_t kGid = getgid(); + time_t start_time = time(nullptr); + + AutoSem sem(semget(IPC_PRIVATE, 10, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + struct semid_ds ds; + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds()); + + EXPECT_EQ(ds.sem_perm.__key, IPC_PRIVATE); + EXPECT_EQ(ds.sem_perm.uid, kUid); + EXPECT_EQ(ds.sem_perm.gid, kGid); + EXPECT_EQ(ds.sem_perm.cuid, kUid); + EXPECT_EQ(ds.sem_perm.cgid, kGid); + EXPECT_EQ(ds.sem_perm.mode, 0600); + // Last semop time is not set on creation. + EXPECT_EQ(ds.sem_otime, 0); + EXPECT_GE(ds.sem_ctime, start_time); + EXPECT_EQ(ds.sem_nsems, 10); + + // The timestamps only have a resolution of seconds; slow down so we actually + // see the timestamps change. + absl::SleepFor(absl::Seconds(1)); + + // Set semid_ds structure of the set. + auto last_ctime = ds.sem_ctime; + start_time = time(nullptr); + struct semid_ds semid_to_set = {}; + semid_to_set.sem_perm.uid = kUid; + semid_to_set.sem_perm.gid = kGid; + semid_to_set.sem_perm.mode = 0666; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds()); + struct sembuf buf = {}; + buf.sem_op = 1; + ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); + + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.mode, 0666); + EXPECT_GE(ds.sem_otime, start_time); + EXPECT_GT(ds.sem_ctime, last_ctime); + + // An invalid semid fails the syscall with errno EINVAL. + EXPECT_THAT(semctl(sem.get() + 1, 0, IPC_STAT, &ds), + SyscallFailsWithErrno(EINVAL)); + + // Make semaphore not readable and check the signal fails. + semid_to_set.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds()); + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), + SyscallFailsWithErrno(EACCES)); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 11fcec443..39a68c5a5 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -350,6 +350,10 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + // TODO(b/157236388): Remove Disable save after bug is fixed. S/R test can + // fail because the last socket may not be delivered to the accept queue + // by the time connect returns. + DisableSave ds; for (int i = 0; i < kBacklog; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); @@ -554,7 +558,11 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { }); } -TEST_P(SocketInetLoopbackTest, TCPbacklog) { +// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/ +// random save as established connections which can't be delivered to the accept +// queue because the queue is full are not correctly delivered after restore +// causing the last accept to timeout on the restore. +TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -567,7 +575,8 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds()); + constexpr int kBacklogSize = 2; + ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; @@ -1143,6 +1152,9 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + // TODO(b/157236388): Reenable Cooperative S/R once bug is fixed. + DisableSave ds; ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), connector.addr_len), diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 3f2c0fdf2..f69f8f99f 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -472,5 +472,19 @@ TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +// Test getsockopt for SO_ACCEPTCONN on udp socket. +TEST_P(UDPSocketPairTest, GetSocketAcceptConn) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index a72c76c97..b3f54e7f6 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -28,6 +28,7 @@ #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/test_util.h" namespace gvisor { @@ -75,7 +76,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -209,7 +210,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -265,7 +266,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -321,7 +322,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -377,7 +378,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -437,7 +438,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -497,7 +498,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -553,7 +554,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -609,7 +610,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -669,7 +670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -727,7 +728,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -785,7 +786,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -845,7 +846,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -919,7 +920,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -977,7 +978,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -1330,8 +1331,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1409,8 +1410,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1432,8 +1433,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { char recv_buf[sizeof(send_buf)] = {}; for (auto& sockets : socket_pairs) { - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -1486,7 +1487,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1530,7 +1531,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -1580,7 +1581,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1627,7 +1628,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1678,7 +1679,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1737,7 +1738,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // of the other sockets to have received it, but we will check that later. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(send_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1745,9 +1746,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // Verify that no other messages were received. for (auto& socket : sockets) { char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - PosixErrorIs(EAGAIN, ::testing::_)); + EXPECT_THAT( + RecvTimeout(socket->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -2108,6 +2109,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { constexpr int kMessageSize = 10; + // Saving during each iteration of the following loop is too expensive. + DisableSave ds; + for (int i = 0; i < 100; ++i) { // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket // each time so that a new ephemerial port will be used each time. This @@ -2120,16 +2124,18 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(sizeof(send_buf))); } + ds.reset(); + // Check that both receivers got messages. This checks that we are using load // balancing (REUSEPORT) instead of the most recently bound socket // (REUSEADDR). char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - IsPosixErrorOkAndHolds(kMessageSize)); - EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT( + RecvTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT( + RecvTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); } // Test that socket will receive packet info control message. @@ -2193,8 +2199,8 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); received_msg.msg_control = received_cmsg_buf; - ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); + ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/), + IsPosixErrorOkAndHolds(kDataLength)); cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); ASSERT_NE(cmsg, nullptr); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc index 49a0f06d9..875016812 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -40,17 +40,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { /*prefixlen=*/24, &addr, sizeof(addr))); Cleanup defer_addr_removal = Cleanup( [loopback_link = std::move(loopback_link), addr = std::move(addr)] { - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses - // via netlink is supported. - EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); - } else { - EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, - sizeof(addr))); - } + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); }); auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -124,17 +116,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { 24 /* prefixlen */, &addr, sizeof(addr))); Cleanup defer_addr_removal = Cleanup( [loopback_link = std::move(loopback_link), addr = std::move(addr)] { - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses - // via netlink is supported. - EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); - } else { - EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, - sizeof(addr))); - } + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); }); TestAddress broadcast_address("SubnetBroadcastAddress"); diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 241ddad74..e83f0d81f 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -511,53 +511,42 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } -TEST(NetlinkRouteTest, AddAddr) { +TEST(NetlinkRouteTest, AddAndRemoveAddr) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + // Don't do cooperative save/restore because netstack state is not restored. + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifa; - struct rtattr rtattr; - struct in_addr addr; - char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_seq = kSeq; - req.ifa.ifa_family = AF_INET; - req.ifa.ifa_prefixlen = 24; - req.ifa.ifa_flags = 0; - req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link.index; - req.rtattr.rta_type = IFA_LOCAL; - req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); - inet_pton(AF_INET, "10.0.0.1", &req.addr); - req.hdr.nlmsg_len = - NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + struct in_addr addr; + ASSERT_EQ(inet_pton(AF_INET, "10.0.0.1", &addr), 1); // Create should succeed, as no such address in kernel. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); + + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + // First delete should succeed, as address exists. + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + + // Second delete should fail, as address no longer exists. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EINVAL, ::testing::_)); + }); // Replace an existing address should succeed. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + ASSERT_NO_ERRNO(LinkReplaceLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); // Create exclusive should fail, as we created the address above. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_THAT( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), - PosixErrorIs(EEXIST, ::testing::_)); + EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EEXIST, ::testing::_)); } // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 7a0bad4cb..46f749c7c 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -29,6 +29,8 @@ constexpr uint32_t kSeq = 12345; // Types of address modifications that may be performed on an interface. enum class LinkAddrModification { kAdd, + kAddExclusive, + kReplace, kDelete, }; @@ -40,6 +42,14 @@ PosixError PopulateNlmsghdr(LinkAddrModification modification, hdr->nlmsg_type = RTM_NEWADDR; hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; return NoError(); + case LinkAddrModification::kAddExclusive: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_EXCL | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kReplace: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + return NoError(); case LinkAddrModification::kDelete: hdr->nlmsg_type = RTM_DELADDR; hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; @@ -144,6 +154,18 @@ PosixError LinkAddLocalAddr(int index, int family, int prefixlen, LinkAddrModification::kAdd); } +PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAddExclusive); +} + +PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kReplace); +} + PosixError LinkDelLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index e5badca70..eaa91ad79 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -39,10 +39,19 @@ PosixErrorOr<std::vector<Link>> DumpLinks(); // Returns the loopback link on the system. ENOENT if not found. PosixErrorOr<Link> LoopbackLink(); -// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +// LinkAddLocalAddr adds a new IFA_LOCAL address to the interface. PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkAddExclusiveLocalAddr adds a new IFA_LOCAL address with NLM_F_EXCL flag +// to the interface. +PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkReplaceLocalAddr replaces an IFA_LOCAL address on the interface. +PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. PosixError LinkDelLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index e11792309..a760581b5 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -753,8 +753,7 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { return ret; } -PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, - int timeout) { +PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout) { fd_set rfd; struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; FD_ZERO(&rfd); @@ -767,6 +766,19 @@ PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, return ret; } +PosixErrorOr<int> RecvMsgTimeout(int sock, struct msghdr* msg, int timeout) { + fd_set rfd; + struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; + FD_ZERO(&rfd); + FD_SET(sock, &rfd); + + int ret; + RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to)); + RETURN_ERROR_IF_SYSCALL_FAIL( + ret = RetryEINTR(recvmsg)(sock, msg, MSG_DONTWAIT)); + return ret; +} + void RecvNoData(int sock) { char data = 0; struct iovec iov; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 468bc96e0..5e205339f 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -467,9 +467,12 @@ PosixError FreeAvailablePort(int port); // SendMsg converts a buffer to an iovec and adds it to msg before sending it. PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size); -// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock. -PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, - int timeout); +// RecvTimeout calls select on sock with timeout and then calls recv on sock. +PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout); + +// RecvMsgTimeout calls select on sock with timeout and then calls recvmsg on +// sock. +PosixErrorOr<int> RecvMsgTimeout(int sock, msghdr* msg, int timeout); // RecvNoData checks that no data is receivable on sock. void RecvNoData(int sock); diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 1edcb15a7..ad9c4bf37 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -121,6 +121,19 @@ TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&got_linger, &sl, length)); } +TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 92260b1e1..6e7142a42 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -31,6 +31,7 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -328,7 +329,10 @@ TEST_F(StatTest, LeadingDoubleSlash) { ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st), SyscallSucceeds()); EXPECT_EQ(st.st_dev, double_slash_st.st_dev); - EXPECT_EQ(st.st_ino, double_slash_st.st_ino); + // Inode numbers for gofer-accessed files may change across save/restore. + if (!IsRunningWithSaveRestore()) { + EXPECT_EQ(st.st_ino, double_slash_st.st_ino); + } } // Test that a rename doesn't change the underlying file. @@ -346,8 +350,14 @@ TEST_F(StatTest, StatDoesntChangeAfterRename) { EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); EXPECT_EQ(st_old.st_dev, st_new.st_dev); + // Inode numbers for gofer-accessed files on which no reference is held may + // change across save/restore because the information that the gofer client + // uses to track file identity (9P QID path) is inconsistent between gofer + // processes, which are restarted across save/restore. + // // Overlay filesystems may synthesize directory inode numbers on the fly. - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { EXPECT_EQ(st_old.st_ino, st_new.st_ino); } EXPECT_EQ(st_old.st_mode, st_new.st_mode); @@ -541,6 +551,26 @@ TEST_F(StatTest, LstatELOOPPath) { ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP)); } +TEST(SimpleStatTest, DifferentFilesHaveDifferentDeviceInodeNumberPairs) { + // TODO(gvisor.dev/issue/1624): This test case fails in VFS1 save/restore + // tests because VFS1 gofer inode number assignment restarts after + // save/restore, such that the inodes for file1 and file2 (which are + // unreferenced and therefore not retained in sentry checkpoints before the + // calls to lstat()) are assigned the same inode number. + SKIP_IF(IsRunningWithVFS1() && IsRunningWithSaveRestore()); + + TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + MaybeSave(); + struct stat st1 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file1.path())); + MaybeSave(); + struct stat st2 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file2.path())); + EXPECT_FALSE(st1.st_dev == st2.st_dev && st1.st_ino == st2.st_ino) + << "both files have device number " << st1.st_dev << " and inode number " + << st1.st_ino; +} + // Ensure that inode allocation for anonymous devices work correctly across // save/restore. In particular, inode numbers should be unique across S/R. TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) { diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 9f522f833..ebd873068 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1725,6 +1725,63 @@ TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) { ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout)); } +// Tests that SO_ACCEPTCONN returns non zero value for listening sockets. +TEST_P(TcpSocketTest, GetSocketAcceptConnListener) { + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(listener_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 1); +} + +// Tests that SO_ACCEPTCONN returns zero value for not listening sockets. +TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) { + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); + + ASSERT_THAT(getsockopt(t_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + +TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { + // TODO(b/171345701): Fix the TCP state for listening socket on shutdown. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 1); + + EXPECT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds()); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 1a7673317..bc5bd9218 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -679,6 +679,43 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { SyscallSucceedsWithValue(sizeof(buf))); } +TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) { + ASSERT_NO_ERRNO(BindLoopback()); + // Close the socket to release the port so that we get an ICMP error. + ASSERT_THAT(close(bind_.release()), SyscallSucceeds()); + + // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an + // UDP socket. There is no easy way to ensure that the UDP port is not bound + // by another conncurrently running test. *This is potentially flaky*. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + constexpr int kTimeout = 1000; + // Poll to make sure we get the ICMP error back before issuing more writes. + struct pollfd pfd = {sock_.get(), POLLERR, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + + // Next write should fail with ECONNREFUSED due to the ICMP error generated in + // response to the previous write. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ECONNREFUSED)); + + // The next write should succeed again since the last write call would have + // retrieved and cleared the socket error. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), SyscallSucceeds()); + + // Poll to make sure we get the ICMP error back before issuing more writes. + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + + // Next write should fail with ECONNREFUSED due to the ICMP error generated in + // response to the previous write. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ECONNREFUSED)); +} + TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); @@ -838,7 +875,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { // Receive the data. It works because it was sent before the connect. char received[sizeof(buf)]; EXPECT_THAT( - RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), + RecvTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -928,9 +965,8 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { SyscallSucceedsWithValue(1)); // We should get the data even though read has been shutdown. - EXPECT_THAT( - RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), - IsPosixErrorOkAndHolds(2)); + EXPECT_THAT(RecvTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), + IsPosixErrorOkAndHolds(2)); // Because we read less than the entire packet length, since it's a packet // based socket any subsequent reads should return EWOULDBLOCK. @@ -1698,8 +1734,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); } @@ -1714,8 +1750,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); } } @@ -1785,8 +1821,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) { for (int i = 0; i < sent - 1; i++) { // Receive the data. std::vector<char> received(buf.size()); - EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); } diff --git a/test/util/BUILD b/test/util/BUILD index 26c2b6a2f..1b028a477 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -155,6 +155,10 @@ cc_library( ], hdrs = ["save_util.h"], defines = select_system(), + deps = [ + ":logging", + "@com_google_absl//absl/types:optional", + ], ) cc_library( diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc index cebf7e0ac..deed0c05b 100644 --- a/test/util/posix_error.cc +++ b/test/util/posix_error.cc @@ -87,7 +87,7 @@ bool PosixErrorIsMatcherCommonImpl::MatchAndExplain( return false; } - if (!message_matcher_.Matches(error.error_message())) { + if (!message_matcher_.Matches(error.message())) { return false; } diff --git a/test/util/posix_error.h b/test/util/posix_error.h index ad666bce0..b634a7f78 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -26,11 +26,6 @@ namespace gvisor { namespace testing { -class PosixErrorIsMatcherCommonImpl; - -template <typename T> -class PosixErrorOr; - class ABSL_MUST_USE_RESULT PosixError { public: PosixError() {} @@ -49,7 +44,8 @@ class ABSL_MUST_USE_RESULT PosixError { // PosixErrorOr. const PosixError& error() const { return *this; } - std::string error_message() const { return msg_; } + int errno_value() const { return errno_; } + std::string message() const { return msg_; } // ToString produces a full string representation of this posix error // including the printable representation of the errno and the error message. @@ -61,14 +57,8 @@ class ABSL_MUST_USE_RESULT PosixError { void IgnoreError() const {} private: - int errno_value() const { return errno_; } int errno_ = 0; std::string msg_; - - friend class PosixErrorIsMatcherCommonImpl; - - template <typename T> - friend class PosixErrorOr; }; template <typename T> @@ -94,15 +84,12 @@ class ABSL_MUST_USE_RESULT PosixErrorOr { template <typename U> PosixErrorOr& operator=(PosixErrorOr<U> other); - // Return a reference to the error or NoError(). - PosixError error() const; - - // Returns this->error().error_message(); - std::string error_message() const; - // Returns true if this PosixErrorOr contains some T. bool ok() const; + // Return a copy of the contained PosixError or NoError(). + PosixError error() const; + // Returns a reference to our current value, or CHECK-fails if !this->ok(). const T& ValueOrDie() const&; T& ValueOrDie() &; @@ -115,7 +102,6 @@ class ABSL_MUST_USE_RESULT PosixErrorOr { void IgnoreError() const {} private: - int errno_value() const; absl::variant<T, PosixError> value_; friend class PosixErrorIsMatcherCommonImpl; @@ -171,16 +157,6 @@ PosixError PosixErrorOr<T>::error() const { } template <typename T> -int PosixErrorOr<T>::errno_value() const { - return error().errno_value(); -} - -template <typename T> -std::string PosixErrorOr<T>::error_message() const { - return error().error_message(); -} - -template <typename T> bool PosixErrorOr<T>::ok() const { return absl::holds_alternative<T>(value_); } diff --git a/test/util/save_util.cc b/test/util/save_util.cc index 384d626f0..59d47e06e 100644 --- a/test/util/save_util.cc +++ b/test/util/save_util.cc @@ -21,35 +21,47 @@ #include <atomic> #include <cerrno> -#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST" +#include "absl/types/optional.h" namespace gvisor { namespace testing { namespace { -enum class CooperativeSaveMode { - kUnknown = 0, // cooperative_save_mode is statically-initialized to 0 - kAvailable, - kNotAvailable, -}; - -std::atomic<CooperativeSaveMode> cooperative_save_mode; - -bool CooperativeSaveEnabled() { - auto mode = cooperative_save_mode.load(); - if (mode == CooperativeSaveMode::kUnknown) { - mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr) - ? CooperativeSaveMode::kAvailable - : CooperativeSaveMode::kNotAvailable; - cooperative_save_mode.store(mode); +std::atomic<absl::optional<bool>> cooperative_save_present; +std::atomic<absl::optional<bool>> random_save_present; + +bool CooperativeSavePresent() { + auto present = cooperative_save_present.load(); + if (!present.has_value()) { + present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != nullptr; + cooperative_save_present.store(present); + } + return present.value(); +} + +bool RandomSavePresent() { + auto present = random_save_present.load(); + if (!present.has_value()) { + present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr; + random_save_present.store(present); } - return mode == CooperativeSaveMode::kAvailable; + return present.value(); } std::atomic<int> save_disable; } // namespace +bool IsRunningWithSaveRestore() { + return CooperativeSavePresent() || RandomSavePresent(); +} + +void MaybeSave() { + if (CooperativeSavePresent() && save_disable.load() == 0) { + internal::DoCooperativeSave(); + } +} + DisableSave::DisableSave() { save_disable++; } DisableSave::~DisableSave() { reset(); } @@ -61,11 +73,5 @@ void DisableSave::reset() { } } -namespace internal { -bool ShouldSave() { - return CooperativeSaveEnabled() && (save_disable.load() == 0); -} -} // namespace internal - } // namespace testing } // namespace gvisor diff --git a/test/util/save_util.h b/test/util/save_util.h index bddad6120..e7218ae88 100644 --- a/test/util/save_util.h +++ b/test/util/save_util.h @@ -17,9 +17,17 @@ namespace gvisor { namespace testing { -// Disable save prevents saving while the given function executes. + +// Returns true if the environment in which the calling process is executing +// allows the test to be checkpointed and restored during execution. +bool IsRunningWithSaveRestore(); + +// May perform a co-operative save cycle. // -// This lasts the duration of the object, unless reset is called. +// errno is guaranteed to be preserved. +void MaybeSave(); + +// Causes MaybeSave to become a no-op until destroyed or reset. class DisableSave { public: DisableSave(); @@ -37,13 +45,13 @@ class DisableSave { bool reset_ = false; }; -// May perform a co-operative save cycle. +namespace internal { + +// Causes a co-operative save cycle to occur. // // errno is guaranteed to be preserved. -void MaybeSave(); +void DoCooperativeSave(); -namespace internal { -bool ShouldSave(); } // namespace internal } // namespace testing diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index fbac94912..57431b3ea 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -30,19 +30,19 @@ namespace gvisor { namespace testing { - -void MaybeSave() { - if (internal::ShouldSave()) { - int orig_errno = errno; - // We use it to trigger saving the sentry state - // when this syscall is called. - // Notice: this needs to be a valid syscall - // that is not used in any of the syscall tests. - syscall(SYS_TRIGGER_SAVE, nullptr, 0); - errno = orig_errno; - } +namespace internal { + +void DoCooperativeSave() { + int orig_errno = errno; + // We use it to trigger saving the sentry state + // when this syscall is called. + // Notice: this needs to be a valid syscall + // that is not used in any of the syscall tests. + syscall(SYS_TRIGGER_SAVE, nullptr, 0); + errno = orig_errno; } +} // namespace internal } // namespace testing } // namespace gvisor diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 931af2c29..7749ded76 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -14,13 +14,17 @@ #ifndef __linux__ +#include "test/util/logging.h" + namespace gvisor { namespace testing { +namespace internal { -void MaybeSave() { - // Saving is never available in a non-linux environment. +void DoCooperativeSave() { + TEST_CHECK_MSG(false, "DoCooperativeSave not implemented"); } +} // namespace internal } // namespace testing } // namespace gvisor diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD index 2b0062a63..1cea9e1c9 100644 --- a/tools/bigquery/BUILD +++ b/tools/bigquery/BUILD @@ -9,5 +9,8 @@ go_library( visibility = [ "//:sandbox", ], - deps = ["@com_google_cloud_go_bigquery//:go_default_library"], + deps = [ + "@com_google_cloud_go_bigquery//:go_default_library", + "@org_golang_google_api//option:go_default_library", + ], ) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 5f1a882de..34b270cc0 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -25,22 +25,30 @@ import ( "time" bq "cloud.google.com/go/bigquery" + "google.golang.org/api/option" ) -// Benchmark is the top level structure of recorded benchmark data. BigQuery +// Suite is the top level structure for a benchmark run. BigQuery // will infer the schema from this. +type Suite struct { + Name string `bq:"name"` + Conditions []*Condition `bq:"conditions"` + Benchmarks []*Benchmark `bq:"benchmarks"` + Official bool `bq:"official"` + Timestamp time.Time `bq:"timestamp"` +} + +// Benchmark represents an individual benchmark in a suite. type Benchmark struct { Name string `bq:"name"` Condition []*Condition `bq:"condition"` - Timestamp time.Time `bq:"timestamp"` - Official bool `bq:"official"` Metric []*Metric `bq:"metric"` - Metadata *Metadata `bq:"metadata"` } -// Condition represents qualifiers for the benchmark. For example: +// Condition represents qualifiers for the benchmark or suite. For example: // Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1" -// and "real_time" parameters as conditions. +// and "real_time" parameters as conditions. Suite conditions include +// information such as the CL number and platform name. type Condition struct { Name string `bq:"name"` Value string `bq:"value"` @@ -53,19 +61,9 @@ type Metric struct { Sample float64 `bq:"sample"` } -// Metadata about this benchmark. -type Metadata struct { - CL string `bq:"changelist"` - IterationID string `bq:"iteration_id"` - PendingCL string `bq:"pending_cl"` - Workflow string `bq:"workflow"` - Platform string `bq:"platform"` - Gofer string `bq:"gofer"` -} - // InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. -func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error { - client, err := bq.NewClient(ctx, projectID) +func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opts []option.ClientOption) error { + client, err := bq.NewClient(ctx, projectID, opts...) if err != nil { return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err) } @@ -77,7 +75,7 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) err } table := dataset.Table(tableID) - schema, err := bq.InferSchema(Benchmark{}) + schema, err := bq.InferSchema(Suite{}) if err != nil { return fmt.Errorf("failed to infer schema: %v", err) } @@ -109,24 +107,32 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { // NewBenchmark initializes a new benchmark. func NewBenchmark(name string, iters int, official bool) *Benchmark { return &Benchmark{ - Name: name, - Timestamp: time.Now().UTC(), - Official: official, - Metric: make([]*Metric, 0), + Name: name, + Metric: make([]*Metric, 0), + } +} + +// NewSuite initializes a new Suite. +func NewSuite(name string) *Suite { + return &Suite{ + Name: name, + Timestamp: time.Now().UTC(), + Benchmarks: make([]*Benchmark, 0), + Conditions: make([]*Condition, 0), } } // SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table. -func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error { - client, err := bq.NewClient(ctx, projectID) +func SendBenchmarks(ctx context.Context, suite *Suite, projectID, datasetID, tableID string, opts []option.ClientOption) error { + client, err := bq.NewClient(ctx, projectID, opts...) if err != nil { return fmt.Errorf("failed to initialize client on project: %s: %v", projectID, err) } defer client.Close() uploader := client.Dataset(datasetID).Table(tableID).Uploader() - if err = uploader.Put(ctx, benchmarks); err != nil { - return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err) + if err = uploader.Put(ctx, suite); err != nil { + return fmt.Errorf("failed to upload benchmarks %s to project %s, table %s.%s: %v", suite.Name, projectID, datasetID, tableID, err) } return nil diff --git a/tools/defs.bzl b/tools/defs.bzl index bb291c512..d75e40863 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -86,8 +86,10 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, ** ) nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = kwargs.get("srcs", []), - library = ":" + name + "_nogo_library", + deps = [":" + name + "_nogo_library"], + tags = ["nogo"], ) def calculate_sets(srcs): @@ -218,8 +220,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F if nogo: nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = all_srcs, - library = ":" + name, + deps = [":" + name], + tags = ["nogo"], ) if marshal: @@ -255,8 +259,10 @@ def go_test(name, nogo = True, **kwargs): if nogo: nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = kwargs.get("srcs", []), - library = ":" + name, + deps = [":" + name], + tags = ["nogo"], ) def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD index 0633eaf19..19b7eec4d 100644 --- a/tools/github/nogo/BUILD +++ b/tools/github/nogo/BUILD @@ -10,7 +10,7 @@ go_library( "//tools/github:__subpackages__", ], deps = [ - "//tools/nogo/util", + "//tools/nogo", "@com_github_google_go_github_v28//github:go_default_library", ], ) diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go index b2bc63459..27ab1b8eb 100644 --- a/tools/github/nogo/nogo.go +++ b/tools/github/nogo/nogo.go @@ -24,7 +24,7 @@ import ( "time" "github.com/google/go-github/github" - "gvisor.dev/gvisor/tools/nogo/util" + "gvisor.dev/gvisor/tools/nogo" ) // FindingsPoster is a simple wrapper around the GitHub api. @@ -35,7 +35,7 @@ type FindingsPoster struct { dryRun bool startTime time.Time - findings map[util.Finding]struct{} + findings map[nogo.Finding]struct{} client *github.Client } @@ -47,7 +47,7 @@ func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun commit: commit, dryRun: dryRun, startTime: time.Now(), - findings: make(map[util.Finding]struct{}), + findings: make(map[nogo.Finding]struct{}), client: client, } } @@ -63,7 +63,7 @@ func (p *FindingsPoster) Walk(paths []string) error { if !strings.HasSuffix(filename, ".findings") || info.IsDir() { return nil } - findings, err := util.ExtractFindingsFromFile(filename) + findings, err := nogo.ExtractFindingsFromFile(filename) if err != nil { return err } @@ -86,7 +86,7 @@ func (p *FindingsPoster) Post() error { if p.dryRun { for finding, _ := range p.findings { // Pretty print, so that this is useful for debugging. - fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message) + fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Position.Filename, finding.Position.Line, finding.Message) } return nil } @@ -115,12 +115,13 @@ func (p *FindingsPoster) Post() error { } annotationLevel := "failure" // Always. for finding, _ := range p.findings { + title := string(finding.Category) opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{ - Path: &finding.Path, - StartLine: &finding.Line, - EndLine: &finding.Line, + Path: &finding.Position.Filename, + StartLine: &finding.Position.Line, + EndLine: &finding.Position.Line, Message: &finding.Message, - Title: &finding.Category, + Title: &title, AnnotationLevel: &annotationLevel, }) } diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go index a95df0fb6..c4b624f2a 100644 --- a/tools/github/reviver/github.go +++ b/tools/github/reviver/github.go @@ -121,13 +121,24 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { return true, nil } +var issuePrefixes = []string{ + "gvisor.dev/issue/", + "gvisor.dev/issues/", +} + // parseIssueNo parses the issue number out of the issue url. +// +// 0 is returned if url does not correspond to an issue. func parseIssueNo(url string) (int, error) { - const prefix = "gvisor.dev/issue/" - // First check if I can handle the TODO. - idStr := strings.TrimPrefix(url, prefix) - if len(url) == len(idStr) { + var idStr string + for _, p := range issuePrefixes { + if str := strings.TrimPrefix(url, p); len(str) < len(url) { + idStr = str + break + } + } + if len(idStr) == 0 { return 0, nil } diff --git a/tools/github/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go index a9fb1f9f1..851306c9d 100644 --- a/tools/github/reviver/reviver_test.go +++ b/tools/github/reviver/reviver_test.go @@ -33,6 +33,15 @@ func TestProcessLine(t *testing.T) { }, }, { + line: "// TODO(foobar.com/issues/123): comment, bla. blabla.", + want: &Todo{ + Issue: "foobar.com/issues/123", + Locations: []Location{ + {Comment: "comment, bla. blabla."}, + }, + }, + }, + { line: "// FIXME(b/123): internal bug", want: &Todo{ Issue: "b/123", diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 3c6be3339..12b8b597c 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -20,20 +20,14 @@ nogo_stdlib( visibility = ["//visibility:public"], ) -sh_binary( - name = "gentest", - srcs = ["gentest.sh"], - visibility = ["//visibility:public"], -) - go_library( name = "nogo", srcs = [ + "analyzers.go", "build.go", "config.go", - "matchers.go", + "findings.go", "nogo.go", - "register.go", ], nogo = False, visibility = ["//:sandbox"], diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go new file mode 100644 index 000000000..b919bc2f8 --- /dev/null +++ b/tools/nogo/analyzers.go @@ -0,0 +1,131 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "encoding/gob" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/asmdecl" + "golang.org/x/tools/go/analysis/passes/assign" + "golang.org/x/tools/go/analysis/passes/atomic" + "golang.org/x/tools/go/analysis/passes/bools" + "golang.org/x/tools/go/analysis/passes/buildtag" + "golang.org/x/tools/go/analysis/passes/cgocall" + "golang.org/x/tools/go/analysis/passes/composite" + "golang.org/x/tools/go/analysis/passes/copylock" + "golang.org/x/tools/go/analysis/passes/errorsas" + "golang.org/x/tools/go/analysis/passes/httpresponse" + "golang.org/x/tools/go/analysis/passes/loopclosure" + "golang.org/x/tools/go/analysis/passes/lostcancel" + "golang.org/x/tools/go/analysis/passes/nilfunc" + "golang.org/x/tools/go/analysis/passes/nilness" + "golang.org/x/tools/go/analysis/passes/printf" + "golang.org/x/tools/go/analysis/passes/shadow" + "golang.org/x/tools/go/analysis/passes/shift" + "golang.org/x/tools/go/analysis/passes/stdmethods" + "golang.org/x/tools/go/analysis/passes/stringintconv" + "golang.org/x/tools/go/analysis/passes/structtag" + "golang.org/x/tools/go/analysis/passes/tests" + "golang.org/x/tools/go/analysis/passes/unmarshal" + "golang.org/x/tools/go/analysis/passes/unreachable" + "golang.org/x/tools/go/analysis/passes/unsafeptr" + "golang.org/x/tools/go/analysis/passes/unusedresult" + "honnef.co/go/tools/staticcheck" + "honnef.co/go/tools/stylecheck" + + "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checkunsafe" +) + +// AllAnalyzers is a list of all available analyzers. +var AllAnalyzers = []*analysis.Analyzer{ + asmdecl.Analyzer, + assign.Analyzer, + atomic.Analyzer, + bools.Analyzer, + buildtag.Analyzer, + cgocall.Analyzer, + composite.Analyzer, + copylock.Analyzer, + errorsas.Analyzer, + httpresponse.Analyzer, + loopclosure.Analyzer, + lostcancel.Analyzer, + nilfunc.Analyzer, + nilness.Analyzer, + printf.Analyzer, + shift.Analyzer, + stdmethods.Analyzer, + stringintconv.Analyzer, + shadow.Analyzer, + structtag.Analyzer, + tests.Analyzer, + unmarshal.Analyzer, + unreachable.Analyzer, + unsafeptr.Analyzer, + unusedresult.Analyzer, + checkescape.Analyzer, + checkunsafe.Analyzer, +} + +// EscapeAnalyzers is a list of escape-related analyzers. +var EscapeAnalyzers = []*analysis.Analyzer{ + checkescape.EscapeAnalyzer, +} + +func register(all []*analysis.Analyzer) { + // Register all fact types. + // + // N.B. This needs to be done recursively, because there may be + // analyzers in the Requires list that do not appear explicitly above. + registered := make(map[*analysis.Analyzer]struct{}) + var registerOne func(*analysis.Analyzer) + registerOne = func(a *analysis.Analyzer) { + if _, ok := registered[a]; ok { + return + } + + // Register dependencies. + for _, da := range a.Requires { + registerOne(da) + } + + // Register local facts. + for _, f := range a.FactTypes { + gob.Register(f) + } + + registered[a] = struct{}{} // Done. + } + for _, a := range all { + registerOne(a) + } +} + +func init() { + // Add all staticcheck analyzers. + for _, a := range staticcheck.Analyzers { + AllAnalyzers = append(AllAnalyzers, a) + } + // Add all stylecheck analyzers. + for _, a := range stylecheck.Analyzers { + AllAnalyzers = append(AllAnalyzers, a) + } + + // Register lists. + register(AllAnalyzers) + register(EscapeAnalyzers) +} diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 55d34760f..d173cff1f 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -20,22 +20,6 @@ import ( "os" ) -var ( - // internalPrefix is the internal path prefix. Note that this is not - // special, as paths should be passed relative to the repository root - // and should not have any special prefix applied. - internalPrefix = fmt.Sprintf("^") - - // internalDefault is applied when no paths are provided. - internalDefault = fmt.Sprintf("%s/.*", notPath("external")) - - // generatedPrefix is a regex for generated files. - generatedPrefix = "^(.*/)?(bazel-genfiles|bazel-out|bazel-bin)/" - - // externalPrefix is external workspace packages. - externalPrefix = "^external/" -) - // findStdPkg needs to find the bundled standard library packages. func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) { if path == "C" { diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD index 21ba2c306..e18483a18 100644 --- a/tools/nogo/check/BUILD +++ b/tools/nogo/check/BUILD @@ -2,8 +2,6 @@ load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) -# Note that the check binary must be public, since an aspect may be applied -# across lots of different rules in different repositories. go_binary( name = "check", srcs = ["main.go"], diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go index 3828edf3a..69bdfe502 100644 --- a/tools/nogo/check/main.go +++ b/tools/nogo/check/main.go @@ -16,9 +16,99 @@ package main import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "gvisor.dev/gvisor/tools/nogo" ) +var ( + packageFile = flag.String("package", "", "package configuration file (in JSON format)") + stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") + findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") + factsOutput = flag.String("facts", "", "output file for facts (optional)") + escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") +) + +func loadConfig(file string, config interface{}) interface{} { + // Load the configuration. + f, err := os.Open(file) + if err != nil { + log.Fatalf("unable to open configuration %q: %v", file, err) + } + defer f.Close() + dec := json.NewDecoder(f) + dec.DisallowUnknownFields() + if err := dec.Decode(config); err != nil { + log.Fatalf("unable to decode configuration: %v", err) + } + return config +} + func main() { - nogo.Main() + // Parse all flags. + flag.Parse() + + var ( + findings []nogo.Finding + factData []byte + err error + ) + + // Check & load the configuration. + if *packageFile != "" && *stdlibFile != "" { + log.Fatalf("unable to perform stdlib and package analysis; provide only one!") + } + + // Run the configuration. + if *stdlibFile != "" { + // Perform basic analysis. + c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig) + findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers) + + } else if *packageFile != "" { + // Perform basic analysis. + c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig) + findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil) + + // Do we need to do escape analysis? + if *escapesOutput != "" { + escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil) + if err != nil { + log.Fatalf("error performing escape analysis: %v", err) + } + if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil { + log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err) + } + } + } else { + log.Fatalf("please provide at least one of package or stdlib!") + } + + // Check that analysis was successful. + if err != nil { + log.Fatalf("error performing analysis: %v", err) + } + + // Save facts. + if *factsOutput != "" { + if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { + log.Fatalf("error saving findings to %q: %v", *factsOutput, err) + } + } + + // Write all findings. + if *findingsOutput != "" { + if err := nogo.WriteFindingsToFile(findings, *findingsOutput); err != nil { + log.Fatalf("error writing findings to %q: %v", *findingsOutput, err) + } + } else { + for _, finding := range findings { + fmt.Fprintf(os.Stdout, "%s\n", finding.String()) + } + } } diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 0853f03cf..2fea5b3e1 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -15,544 +15,247 @@ package nogo import ( - "golang.org/x/tools/go/analysis" - "golang.org/x/tools/go/analysis/passes/asmdecl" - "golang.org/x/tools/go/analysis/passes/assign" - "golang.org/x/tools/go/analysis/passes/atomic" - "golang.org/x/tools/go/analysis/passes/bools" - "golang.org/x/tools/go/analysis/passes/buildtag" - "golang.org/x/tools/go/analysis/passes/cgocall" - "golang.org/x/tools/go/analysis/passes/composite" - "golang.org/x/tools/go/analysis/passes/copylock" - "golang.org/x/tools/go/analysis/passes/errorsas" - "golang.org/x/tools/go/analysis/passes/httpresponse" - "golang.org/x/tools/go/analysis/passes/loopclosure" - "golang.org/x/tools/go/analysis/passes/lostcancel" - "golang.org/x/tools/go/analysis/passes/nilfunc" - "golang.org/x/tools/go/analysis/passes/nilness" - "golang.org/x/tools/go/analysis/passes/printf" - "golang.org/x/tools/go/analysis/passes/shadow" - "golang.org/x/tools/go/analysis/passes/shift" - "golang.org/x/tools/go/analysis/passes/stdmethods" - "golang.org/x/tools/go/analysis/passes/stringintconv" - "golang.org/x/tools/go/analysis/passes/structtag" - "golang.org/x/tools/go/analysis/passes/tests" - "golang.org/x/tools/go/analysis/passes/unmarshal" - "golang.org/x/tools/go/analysis/passes/unreachable" - "golang.org/x/tools/go/analysis/passes/unsafeptr" - "golang.org/x/tools/go/analysis/passes/unusedresult" - "honnef.co/go/tools/staticcheck" - "honnef.co/go/tools/stylecheck" - - "gvisor.dev/gvisor/tools/checkescape" - "gvisor.dev/gvisor/tools/checkunsafe" + "fmt" + "regexp" ) -var analyzerConfig = map[*analysis.Analyzer]matcher{ - // Standard analyzers. - asmdecl.Analyzer: alwaysMatches(), - assign.Analyzer: externalExcluded( - ".*gazelle/walk/walk.go", // False positive. - ), - atomic.Analyzer: alwaysMatches(), - bools.Analyzer: alwaysMatches(), - buildtag.Analyzer: alwaysMatches(), - cgocall.Analyzer: alwaysMatches(), - composite.Analyzer: and( - disableMatches(), // Disabled for now. - resultExcluded{ - "Object_", - "Range{", - }, - ), - copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos). - errorsas.Analyzer: alwaysMatches(), - httpresponse.Analyzer: alwaysMatches(), - loopclosure.Analyzer: alwaysMatches(), - lostcancel.Analyzer: internalMatches(), // Common external issues. - nilfunc.Analyzer: alwaysMatches(), - nilness.Analyzer: and( - internalMatches(), // Common "tautological checks". - internalExcluded( - "pkg/sentry/platform/kvm/kvm_test.go", // Intentional. - "tools/bigquery/bigquery.go", // False positive. - ), - ), - printf.Analyzer: alwaysMatches(), - shift.Analyzer: alwaysMatches(), - stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write"). - stringintconv.Analyzer: and( - internalExcluded(), - externalExcluded( - ".*protobuf/.*.go", // Bad conversions. - ".*flate/huffman_bit_writer.go", // Bad conversion. +// GroupName is a named group. +type GroupName string + +// AnalyzerName is a named analyzer. +type AnalyzerName string + +// Group represents a named collection of files. +type Group struct { + // Name is the short name for the group. + Name GroupName `yaml:"name"` + + // Regex matches all full paths in the group. + Regex string `yaml:"regex"` + regex *regexp.Regexp `yaml:"-"` + + // Default determines the default group behavior. + // + // If Default is true, all Analyzers are enabled for this + // group. Otherwise, Analyzers must be individually enabled + // by specifying a (possible empty) ItemConfig for the group + // in the AnalyzerConfig. + Default bool `yaml:"default"` +} + +func (g *Group) compile() error { + r, err := regexp.Compile(g.Regex) + if err != nil { + return err + } + g.regex = r + return nil +} + +// ItemConfig is an (Analyzer,Group) configuration. +type ItemConfig struct { + // Exclude are analyzer exclusions. + // + // Exclude is a list of regular expressions. If the corresponding + // Analyzer emits a Finding for which Finding.Position.String() + // matches a regular expression in Exclude, the finding will not + // be reported. + Exclude []string `yaml:"exclude,omitempty"` + exclude []*regexp.Regexp `yaml:"-"` + + // Suppress are analyzer suppressions. + // + // Suppress is a list of regular expressions. If the corresponding + // Analyzer emits a Finding for which Finding.Message matches a regular + // expression in Suppress, the finding will not be reported. + Suppress []string `yaml:"suppress,omitempty"` + suppress []*regexp.Regexp `yaml:"-"` +} + +func compileRegexps(ss []string, rs *[]*regexp.Regexp) error { + *rs = make([]*regexp.Regexp, 0, len(ss)) + for _, s := range ss { + r, err := regexp.Compile(s) + if err != nil { + return err + } + *rs = append(*rs, r) + } + return nil +} + +func (i *ItemConfig) compile() error { + if i == nil { + // This may be nil if nothing is included in the + // item configuration. That's fine, there's nothing + // to compile and nothing to exclude & suppress. + return nil + } + if err := compileRegexps(i.Exclude, &i.exclude); err != nil { + return fmt.Errorf("in exclude: %w", err) + } + if err := compileRegexps(i.Suppress, &i.suppress); err != nil { + return fmt.Errorf("in suppress: %w", err) + } + return nil +} + +func (i *ItemConfig) merge(other *ItemConfig) { + i.Exclude = append(i.Exclude, other.Exclude...) + i.Suppress = append(i.Suppress, other.Suppress...) +} + +func (i *ItemConfig) shouldReport(fullPos, msg string) bool { + if i == nil { + // See above. + return true + } + for _, r := range i.exclude { + if r.MatchString(fullPos) { + return false + } + } + for _, r := range i.suppress { + if r.MatchString(msg) { + return false + } + } + return true +} + +// AnalyzerConfig is the configuration for a single analyzers. +// +// This map is keyed by individual Group names, to allow for different +// configurations depending on what Group the file belongs to. +type AnalyzerConfig map[GroupName]*ItemConfig + +func (a AnalyzerConfig) compile() error { + for name, gc := range a { + if err := gc.compile(); err != nil { + return fmt.Errorf("invalid group %q: %v", name, err) + } + } + return nil +} + +func (a AnalyzerConfig) merge(other AnalyzerConfig) { + // Merge all the groups. + for name, gc := range other { + old, ok := a[name] + if !ok || old == nil { + a[name] = gc // Not configured in a. + continue + } + old.merge(gc) + } +} + +func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) bool { + gc, ok := a[groupConfig.Name] + if !ok { + return groupConfig.Default + } + + // Note that if a section appears for a particular group + // for a particular analyzer, then it will now be enabled, + // and the group default no longer applies. + return gc.shouldReport(fullPos, msg) +} + +// Config is a nogo configuration. +type Config struct { + // Prefixes defines a set of regular expressions that + // are standard "prefixes", so that files can be grouped + // and specific rules applied to individual groups. + Groups []Group `yaml:"groups"` - // Runtime internal violations. - ".*reflect/value.go", - ".*encoding/xml/xml.go", - ".*runtime/pprof/internal/profile/proto.go", - ".*fmt/scan.go", - ".*go/types/conversions.go", - ".*golang.org/x/net/dns/dnsmessage/message.go", - ), - ), - shadow.Analyzer: disableMatches(), // Disabled for now. - structtag.Analyzer: internalMatches(), // External not subject to rules. - tests.Analyzer: alwaysMatches(), - unmarshal.Analyzer: alwaysMatches(), - unreachable.Analyzer: internalMatches(), - unsafeptr.Analyzer: and( - internalMatches(), - internalExcluded( - ".*_test.go", // Exclude tests. - "pkg/flipcall/.*_unsafe.go", // Special case. - "pkg/gohacks/gohacks_unsafe.go", // Special case. - "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case. - "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case. - "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case. - "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case. - "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case. - "pkg/sentry/vfs/mount_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case. - ), - ), - unusedresult.Analyzer: alwaysMatches(), + // Global is the global analyzer config. + Global AnalyzerConfig `yaml:"global"` - // Internal analyzers: external packages not subject. - checkescape.Analyzer: internalMatches(), - checkunsafe.Analyzer: internalMatches(), + // Analyzers are individual analyzer configurations. The + // key for each analyzer is the name of the analyzer. The + // value is either a boolean (enable/disable), or a map to + // the groups above. + Analyzers map[AnalyzerName]AnalyzerConfig `yaml:"analyzers"` } -func init() { - staticMatcher := and( - // Only match internal, non-generated files. - internalMatches(), - generatedExcluded(), +// Merge merges two configurations. +func (c *Config) Merge(other *Config) { + // Merge all groups. + for _, g := range other.Groups { + // Is there a matching group? If yes, we just delete + // it. This will preserve the order provided in the + // overriding file, even if it differs. + for i := 0; i < len(c.Groups); i++ { + if g.Name == c.Groups[i].Name { + copy(c.Groups[i:], c.Groups[i+1:]) + c.Groups = c.Groups[:len(c.Groups)-1] + break + } + } + c.Groups = append(c.Groups, g) + } - // We use ALL_CAPS for system definitions, - // which are common enough in the code base - // that we shouldn't annotate exceptions. - // - // Same story for underscores. - resultExcluded([]string{ - "should not use ALL_CAPS in Go names", - "should not use underscores in Go names", - }), + // Merge global configurations. + c.Global.merge(other.Global) - // Exclude existing matches. - internalExcluded( - "pkg/abi/linux/fuse.go:22", - "pkg/abi/linux/fuse.go:25", - "pkg/abi/linux/socket.go:113", - "pkg/abi/linux/tty.go:73", - "pkg/bpf/decoder.go:112", - "pkg/cpuid/cpuid_x86.go:675", - "pkg/eventchannel/event.go:193", - "pkg/eventchannel/event.go:27", - "pkg/eventchannel/event_test.go:22", - "pkg/eventchannel/rate.go:19", - "pkg/gohacks/gohacks_unsafe.go:33", - "pkg/log/json.go:30", - "pkg/log/log.go:359", - "pkg/merkletree/merkletree.go:230", - "pkg/merkletree/merkletree.go:243", - "pkg/merkletree/merkletree.go:249", - "pkg/merkletree/merkletree.go:266", - "pkg/merkletree/merkletree.go:355", - "pkg/merkletree/merkletree.go:369", - "pkg/metric/metric_test.go:20", - "pkg/p9/p9test/client_test.go:687", - "pkg/p9/transport_test.go:196", - "pkg/pool/pool.go:15", - "pkg/refs/refcounter.go:510", - "pkg/refs/refcounter_test.go:169", - "pkg/refs_vfs2/refs.go:16", - "pkg/safemem/block_unsafe.go:89", - "pkg/seccomp/seccomp.go:82", - "pkg/segment/test/set_functions.go:15", - "pkg/sentry/arch/signal.go:166", - "pkg/sentry/arch/signal.go:171", - "pkg/sentry/control/pprof.go:196", - "pkg/sentry/devices/memdev/full.go:58", - "pkg/sentry/devices/memdev/null.go:59", - "pkg/sentry/devices/memdev/random.go:68", - "pkg/sentry/devices/memdev/zero.go:86", - "pkg/sentry/fdimport/fdimport.go:15", - "pkg/sentry/fs/attr.go:257", - "pkg/sentry/fsbridge/fs.go:116", - "pkg/sentry/fsbridge/vfs.go:124", - "pkg/sentry/fsbridge/vfs.go:70", - "pkg/sentry/fs/copy_up.go:365", - "pkg/sentry/fs/copy_up_test.go:65", - "pkg/sentry/fs/dev/net_tun.go:161", - "pkg/sentry/fs/dev/net_tun.go:63", - "pkg/sentry/fs/dev/null.go:97", - "pkg/sentry/fs/dirent_cache.go:64", - "pkg/sentry/fs/file_overlay.go:327", - "pkg/sentry/fs/file_overlay.go:524", - "pkg/sentry/fs/filetest/filetest.go:55", - "pkg/sentry/fs/filetest/filetest.go:60", - "pkg/sentry/fs/fs.go:77", - "pkg/sentry/fs/fsutil/file.go:290", - "pkg/sentry/fs/fsutil/file.go:346", - "pkg/sentry/fs/fsutil/host_file_mapper.go:105", - "pkg/sentry/fs/fsutil/inode_cached.go:676", - "pkg/sentry/fs/fsutil/inode_cached.go:772", - "pkg/sentry/fs/gofer/attr.go:120", - "pkg/sentry/fs/gofer/fifo.go:33", - "pkg/sentry/fs/gofer/inode.go:410", - "pkg/sentry/fsimpl/devpts/devpts.go:110", - "pkg/sentry/fsimpl/devpts/devpts.go:246", - "pkg/sentry/fsimpl/devpts/devpts.go:50", - "pkg/sentry/fsimpl/devpts/master.go:110", - "pkg/sentry/fsimpl/devpts/master.go:55", - "pkg/sentry/fsimpl/devpts/replica.go:113", - "pkg/sentry/fsimpl/devpts/replica.go:57", - "pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54", - "pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97", - "pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92", - "pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44", - "pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91", - "pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93", - "pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66", - "pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53", - "pkg/sentry/fsimpl/eventfd/eventfd.go:268", - "pkg/sentry/fsimpl/ext/directory.go:163", - "pkg/sentry/fsimpl/ext/directory.go:164", - "pkg/sentry/fsimpl/ext/extent_file.go:142", - "pkg/sentry/fsimpl/ext/extent_file.go:143", - "pkg/sentry/fsimpl/ext/ext.go:105", - "pkg/sentry/fsimpl/ext/filesystem.go:287", - "pkg/sentry/fsimpl/ext/regular_file.go:153", - "pkg/sentry/fsimpl/ext/symlink.go:113", - "pkg/sentry/fsimpl/fuse/connection_control.go:194", - "pkg/sentry/fsimpl/fuse/dev.go:387", - "pkg/sentry/fsimpl/fuse/dev_test.go:318", - "pkg/sentry/fsimpl/fuse/fusefs.go:102", - "pkg/sentry/fsimpl/fuse/read_write.go:129", - "pkg/sentry/fsimpl/fuse/request_response.go:71", - "pkg/sentry/fsimpl/gofer/directory.go:135", - "pkg/sentry/fsimpl/gofer/filesystem.go:679", - "pkg/sentry/fsimpl/gofer/gofer.go:1694", - "pkg/sentry/fsimpl/gofer/gofer.go:276", - "pkg/sentry/fsimpl/gofer/regular_file.go:81", - "pkg/sentry/fsimpl/gofer/special_file.go:141", - "pkg/sentry/fsimpl/host/host.go:184", - "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50", - "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90", - "pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273", - "pkg/sentry/fsimpl/kernfs/filesystem.go:247", - "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320", - "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497", - "pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52", - "pkg/sentry/fsimpl/overlay/directory.go:119", - "pkg/sentry/fsimpl/overlay/filesystem.go:527", - "pkg/sentry/fsimpl/overlay/non_directory.go:152", - "pkg/sentry/fsimpl/overlay/overlay.go:115", - "pkg/sentry/fsimpl/overlay/overlay.go:719", - "pkg/sentry/fsimpl/pipefs/pipefs.go:74", - "pkg/sentry/fsimpl/proc/filesystem.go:52", - "pkg/sentry/fsimpl/proc/filesystem.go:81", - "pkg/sentry/fsimpl/proc/subtasks.go:126", - "pkg/sentry/fsimpl/proc/subtasks.go:189", - "pkg/sentry/fsimpl/proc/task_fds.go:168", - "pkg/sentry/fsimpl/proc/task_fds.go:228", - "pkg/sentry/fsimpl/proc/task_fds.go:301", - "pkg/sentry/fsimpl/proc/task_fds.go:318", - "pkg/sentry/fsimpl/proc/task_fds.go:67", - "pkg/sentry/fsimpl/proc/task_files.go:112", - "pkg/sentry/fsimpl/proc/task_files.go:158", - "pkg/sentry/fsimpl/proc/task_files.go:259", - "pkg/sentry/fsimpl/proc/task_files.go:285", - "pkg/sentry/fsimpl/proc/task_files.go:305", - "pkg/sentry/fsimpl/proc/task_files.go:384", - "pkg/sentry/fsimpl/proc/task_files.go:403", - "pkg/sentry/fsimpl/proc/task_files.go:428", - "pkg/sentry/fsimpl/proc/task_files.go:691", - "pkg/sentry/fsimpl/proc/task_files.go:770", - "pkg/sentry/fsimpl/proc/task_files.go:797", - "pkg/sentry/fsimpl/proc/task_files.go:828", - "pkg/sentry/fsimpl/proc/task_files.go:879", - "pkg/sentry/fsimpl/proc/task_files.go:910", - "pkg/sentry/fsimpl/proc/task_files.go:961", - "pkg/sentry/fsimpl/proc/task.go:127", - "pkg/sentry/fsimpl/proc/task.go:193", - "pkg/sentry/fsimpl/proc/task_net.go:134", - "pkg/sentry/fsimpl/proc/task_net.go:475", - "pkg/sentry/fsimpl/proc/task_net.go:491", - "pkg/sentry/fsimpl/proc/task_net.go:508", - "pkg/sentry/fsimpl/proc/task_net.go:665", - "pkg/sentry/fsimpl/proc/task_net.go:715", - "pkg/sentry/fsimpl/proc/task_net.go:779", - "pkg/sentry/fsimpl/proc/tasks_files.go:113", - "pkg/sentry/fsimpl/proc/tasks_files.go:388", - "pkg/sentry/fsimpl/proc/tasks.go:232", - "pkg/sentry/fsimpl/proc/tasks_sys.go:145", - "pkg/sentry/fsimpl/proc/tasks_sys.go:181", - "pkg/sentry/fsimpl/proc/tasks_sys.go:239", - "pkg/sentry/fsimpl/proc/tasks_sys.go:291", - "pkg/sentry/fsimpl/proc/tasks_sys.go:375", - "pkg/sentry/fsimpl/signalfd/signalfd.go:124", - "pkg/sentry/fsimpl/signalfd/signalfd.go:15", - "pkg/sentry/fsimpl/signalfd/signalfd.go:126", - "pkg/sentry/fsimpl/sockfs/sockfs.go:36", - "pkg/sentry/fsimpl/sockfs/sockfs.go:79", - "pkg/sentry/fsimpl/sys/kcov.go:49", - "pkg/sentry/fsimpl/sys/kcov.go:99", - "pkg/sentry/fsimpl/sys/sys.go:118", - "pkg/sentry/fsimpl/sys/sys.go:56", - "pkg/sentry/fsimpl/testutil/testutil.go:257", - "pkg/sentry/fsimpl/testutil/testutil.go:260", - "pkg/sentry/fsimpl/timerfd/timerfd.go:87", - "pkg/sentry/fsimpl/tmpfs/directory.go:112", - "pkg/sentry/fsimpl/tmpfs/filesystem.go:195", - "pkg/sentry/fsimpl/tmpfs/regular_file.go:226", - "pkg/sentry/fsimpl/tmpfs/regular_file.go:346", - "pkg/sentry/fsimpl/tmpfs/tmpfs.go:103", - "pkg/sentry/fsimpl/tmpfs/tmpfs.go:733", - "pkg/sentry/fsimpl/verity/filesystem.go:490", - "pkg/sentry/fsimpl/verity/verity.go:156", - "pkg/sentry/fsimpl/verity/verity.go:629", - "pkg/sentry/fsimpl/verity/verity.go:672", - "pkg/sentry/fs/mount.go:162", - "pkg/sentry/fs/mount.go:256", - "pkg/sentry/fs/mount_overlay.go:144", - "pkg/sentry/fs/mounts.go:432", - "pkg/sentry/fs/proc/exec_args.go:104", - "pkg/sentry/fs/proc/exec_args.go:73", - "pkg/sentry/fs/proc/fds.go:269", - "pkg/sentry/fs/proc/loadavg.go:33", - "pkg/sentry/fs/proc/meminfo.go:39", - "pkg/sentry/fs/proc/mounts.go:193", - "pkg/sentry/fs/proc/mounts.go:84", - "pkg/sentry/fs/proc/net.go:125", - "pkg/sentry/fs/proc/proc.go:146", - "pkg/sentry/fs/proc/proc.go:204", - "pkg/sentry/fs/proc/seqfile/seqfile.go:210", - "pkg/sentry/fs/proc/sys.go:146", - "pkg/sentry/fs/proc/sys.go:43", - "pkg/sentry/fs/proc/sys_net.go:113", - "pkg/sentry/fs/proc/sys_net.go:205", - "pkg/sentry/fs/proc/sys_net.go:233", - "pkg/sentry/fs/proc/sys_net.go:307", - "pkg/sentry/fs/proc/sys_net.go:335", - "pkg/sentry/fs/proc/sys_net.go:446", - "pkg/sentry/fs/proc/sys_net.go:456", - "pkg/sentry/fs/proc/sys_net.go:89", - "pkg/sentry/fs/proc/task.go:170", - "pkg/sentry/fs/proc/task.go:322", - "pkg/sentry/fs/proc/task.go:427", - "pkg/sentry/fs/proc/task.go:467", - "pkg/sentry/fs/proc/task.go:500", - "pkg/sentry/fs/proc/task.go:784", - "pkg/sentry/fs/proc/task.go:839", - "pkg/sentry/fs/proc/task.go:920", - "pkg/sentry/fs/proc/uid_gid_map.go:108", - "pkg/sentry/fs/proc/uid_gid_map.go:79", - "pkg/sentry/fs/proc/uptime.go:75", - "pkg/sentry/fs/ramfs/dir.go:447", - "pkg/sentry/fs/tmpfs/inode_file.go:436", - "pkg/sentry/fs/tmpfs/inode_file.go:537", - "pkg/sentry/fs/tty/dir.go:313", - "pkg/sentry/fs/tty/master.go:131", - "pkg/sentry/fs/tty/master.go:91", - "pkg/sentry/fs/tty/replica.go:116", - "pkg/sentry/fs/tty/replica.go:88", - "pkg/sentry/kernel/auth/id_map.go:269", - "pkg/sentry/kernel/fasync/fasync.go:67", - "pkg/sentry/kernel/kcov.go:209", - "pkg/sentry/kernel/kcov.go:223", - "pkg/sentry/kernel/kernel.go:343", - "pkg/sentry/kernel/kernel.go:368", - "pkg/sentry/kernel/pipe/node_test.go:112", - "pkg/sentry/kernel/pipe/node_test.go:119", - "pkg/sentry/kernel/pipe/node_test.go:130", - "pkg/sentry/kernel/pipe/node_test.go:137", - "pkg/sentry/kernel/pipe/node_test.go:149", - "pkg/sentry/kernel/pipe/node_test.go:150", - "pkg/sentry/kernel/pipe/node_test.go:158", - "pkg/sentry/kernel/pipe/node_test.go:174", - "pkg/sentry/kernel/pipe/node_test.go:180", - "pkg/sentry/kernel/pipe/node_test.go:193", - "pkg/sentry/kernel/pipe/node_test.go:202", - "pkg/sentry/kernel/pipe/node_test.go:205", - "pkg/sentry/kernel/pipe/node_test.go:216", - "pkg/sentry/kernel/pipe/node_test.go:219", - "pkg/sentry/kernel/pipe/node_test.go:271", - "pkg/sentry/kernel/pipe/node_test.go:290", - "pkg/sentry/kernel/pipe/pipe_test.go:93", - "pkg/sentry/kernel/pipe/reader_writer.go:65", - "pkg/sentry/kernel/posixtimer.go:157", - "pkg/sentry/kernel/ptrace.go:218", - "pkg/sentry/kernel/semaphore/semaphore.go:323", - "pkg/sentry/kernel/sessions.go:123", - "pkg/sentry/kernel/sessions.go:508", - "pkg/sentry/kernel/signal_handlers.go:57", - "pkg/sentry/kernel/task_context.go:72", - "pkg/sentry/kernel/task_exit.go:67", - "pkg/sentry/kernel/task_sched.go:255", - "pkg/sentry/kernel/task_sched.go:280", - "pkg/sentry/kernel/task_sched.go:323", - "pkg/sentry/kernel/task_stop.go:192", - "pkg/sentry/kernel/thread_group.go:530", - "pkg/sentry/kernel/timekeeper.go:316", - "pkg/sentry/kernel/vdso.go:106", - "pkg/sentry/kernel/vdso.go:118", - "pkg/sentry/memmap/memmap.go:103", - "pkg/sentry/memmap/memmap.go:163", - "pkg/sentry/mm/address_space.go:42", - "pkg/sentry/mm/address_space.go:42", - "pkg/sentry/mm/aio_context.go:208", - "pkg/sentry/mm/aio_context.go:288", - "pkg/sentry/mm/pma.go:683", - "pkg/sentry/mm/special_mappable.go:80", - "pkg/sentry/platform/systrap/subprocess.go:370", - "pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go:124", - "pkg/sentry/socket/control/control.go:260", - "pkg/sentry/socket/control/control.go:94", - "pkg/sentry/socket/control/control_vfs2.go:37", - "pkg/sentry/socket/hostinet/stack.go:433", - "pkg/sentry/socket/hostinet/stack.go:438", - "pkg/sentry/socket/hostinet/stack.go:444", - "pkg/sentry/socket/hostinet/stack.go:460", - "pkg/sentry/socket/netfilter/tcp_matcher.go:74", - "pkg/sentry/socket/netfilter/udp_matcher.go:71", - "pkg/sentry/socket/netlink/route/protocol.go:38", - "pkg/sentry/socket/socket.go:332", - "pkg/sentry/socket/unix/transport/connectioned.go:394", - "pkg/sentry/socket/unix/transport/connectionless.go:152", - "pkg/sentry/socket/unix/transport/unix.go:436", - "pkg/sentry/socket/unix/transport/unix.go:490", - "pkg/sentry/socket/unix/transport/unix.go:685", - "pkg/sentry/socket/unix/transport/unix.go:795", - "pkg/sentry/syscalls/linux/sys_sem.go:62", - "pkg/sentry/syscalls/linux/sys_time.go:189", - "pkg/sentry/usage/cpu.go:42", - "pkg/sentry/vfs/anonfs.go:302", - "pkg/sentry/vfs/anonfs.go:99", - "pkg/sentry/vfs/dentry.go:214", - "pkg/sentry/vfs/epoll.go:168", - "pkg/sentry/vfs/epoll.go:314", - "pkg/sentry/vfs/file_description.go:549", - "pkg/sentry/vfs/file_description_impl_util.go:304", - "pkg/sentry/vfs/file_description_impl_util.go:412", - "pkg/sentry/vfs/filesystem.go:76", - "pkg/sentry/vfs/lock.go:15", - "pkg/sentry/vfs/lock.go:47", - "pkg/sentry/vfs/memxattr/xattr.go:37", - "pkg/sentry/vfs/mount.go:510", - "pkg/sentry/vfs/mount.go:667", - "pkg/sentry/vfs/mount_test.go:106", - "pkg/sentry/vfs/mount_test.go:160", - "pkg/sentry/vfs/mount_test.go:215", - "pkg/sentry/vfs/mount_unsafe.go:153", - "pkg/sentry/vfs/resolving_path.go:228", - "pkg/sentry/vfs/vfs.go:897", - "pkg/shim/runsc/runsc.go:16", - "pkg/shim/runsc/utils.go:16", - "pkg/shim/v1/proc/deleted_state.go:16", - "pkg/shim/v1/proc/exec.go:16", - "pkg/shim/v1/proc/exec_state.go:16", - "pkg/shim/v1/proc/init.go:16", - "pkg/shim/v1/proc/init_state.go:16", - "pkg/shim/v1/proc/io.go:16", - "pkg/shim/v1/proc/process.go:16", - "pkg/shim/v1/proc/types.go:16", - "pkg/shim/v1/proc/utils.go:16", - "pkg/shim/v1/shim/api.go:16", - "pkg/shim/v1/shim/platform.go:16", - "pkg/shim/v1/shim/service.go:16", - "pkg/shim/v1/utils/annotations.go:15", - "pkg/shim/v1/utils/utils.go:15", - "pkg/shim/v1/utils/volumes.go:15", - "pkg/shim/v2/api.go:16", - "pkg/shim/v2/epoll.go:18", - "pkg/shim/v2/options/options.go:15", - "pkg/shim/v2/options/options.go:24", - "pkg/shim/v2/options/options.go:26", - "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16", - "pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go", // Generated: exempt all. - "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22", - "pkg/shim/v2/service.go:15", - "pkg/shim/v2/service_linux.go:18", - "pkg/state/tests/integer_test.go:23", - "pkg/state/tests/integer_test.go:28", - "pkg/sync/rwmutex_test.go:105", - "pkg/syserr/host_linux.go:35", - "pkg/tcpip/adapters/gonet/gonet_test.go:144", - "pkg/tcpip/adapters/gonet/gonet_test.go:415", - "pkg/tcpip/adapters/gonet/gonet_test.go:99", - "pkg/tcpip/buffer/view.go:238", - "pkg/tcpip/buffer/view.go:238", - "pkg/tcpip/buffer/view.go:246", - "pkg/tcpip/header/tcp.go:151", - "pkg/tcpip/link/sharedmem/pipe/pipe_test.go:493", - "pkg/tcpip/stack/iptables.go:293", - "pkg/tcpip/stack/iptables_types.go:277", - "pkg/tcpip/stack/stack.go:553", - "pkg/tcpip/stack/transport_test.go:30", - "pkg/tcpip/transport/packet/endpoint.go:126", - "pkg/tcpip/transport/raw/endpoint.go:145", - "pkg/tcpip/transport/tcp/sack_scoreboard.go:167", - "pkg/unet/unet_test.go:634", - "pkg/unet/unet_test.go:662", - "pkg/unet/unet_test.go:703", - "pkg/unet/unet_test.go:98", - "pkg/usermem/addr.go:34", - "pkg/usermem/usermem.go:171", - "pkg/usermem/usermem.go:170", - "runsc/boot/compat.go:22", - "runsc/boot/compat.go:56", - "runsc/boot/loader.go:1115", - "runsc/boot/loader.go:1120", - "runsc/cmd/checkpoint.go:151", - "runsc/config/flags.go:32", - "runsc/container/container.go:641", - "runsc/container/container.go:988", - "runsc/specutils/specutils.go:172", - "runsc/specutils/specutils.go:428", - "runsc/specutils/specutils.go:436", - "runsc/specutils/specutils.go:442", - "runsc/specutils/specutils.go:447", - "runsc/specutils/specutils.go:454", - "test/cmd/test_app/fds.go:171", - "test/iptables/filter_output.go:251", - "test/packetimpact/testbench/connections.go:77", - "tools/bigquery/bigquery.go:106", - "tools/checkescape/test1/test1.go:108", - "tools/checkescape/test1/test1.go:122", - "tools/checkescape/test1/test1.go:137", - "tools/checkescape/test1/test1.go:151", - "tools/checkescape/test1/test1.go:170", - "tools/checkescape/test1/test1.go:39", - "tools/checkescape/test1/test1.go:45", - "tools/checkescape/test1/test1.go:50", - "tools/checkescape/test1/test1.go:64", - "tools/checkescape/test1/test1.go:80", - "tools/checkescape/test1/test1.go:94", - "tools/go_generics/imports.go:51", - "tools/go_generics/imports.go:75", - "tools/go_marshal/gomarshal/generator.go:177", - "tools/go_marshal/gomarshal/generator.go:81", - "tools/go_marshal/gomarshal/generator.go:85", - "tools/go_marshal/test/escape/escape.go:15", - "tools/go_marshal/test/test.go:164", - ), - ) + // Merge all analyzer configurations. + for name, ac := range other.Analyzers { + old, ok := c.Analyzers[name] + if !ok { + c.Analyzers[name] = ac // No analyzer in original config. + continue + } + old.merge(ac) + } +} - // Add all staticcheck analyzers; internal only. - for _, a := range staticcheck.Analyzers { - analyzerConfig[a] = staticMatcher +// Compile compiles a configuration to make it useable. +func (c *Config) Compile() error { + for i := 0; i < len(c.Groups); i++ { + if err := c.Groups[i].compile(); err != nil { + return fmt.Errorf("invalid group %q: %w", c.Groups[i].Name, err) + } } - // Add all stylecheck analyzers; internal only. - for _, a := range stylecheck.Analyzers { - analyzerConfig[a] = staticMatcher + if err := c.Global.compile(); err != nil { + return fmt.Errorf("invalid global: %w", err) } + for name, ac := range c.Analyzers { + if err := ac.compile(); err != nil { + return fmt.Errorf("invalid analyzer %q: %w", name, err) + } + } + return nil } -var escapesConfig = map[*analysis.Analyzer]matcher{ - // Informational only: include all packages. - checkescape.EscapeAnalyzer: alwaysMatches(), +// ShouldReport returns true iff the finding should match the Config. +func (c *Config) ShouldReport(finding Finding) bool { + fullPos := finding.Position.String() + + // Find the matching group. + var groupConfig *Group + for i := 0; i < len(c.Groups); i++ { + if c.Groups[i].regex.MatchString(fullPos) { + groupConfig = &c.Groups[i] + break + } + } + + // If there is no group matching this path, then + // we default to accept the finding. + if groupConfig == nil { + return true + } + + // Suppress via global rule? + if !c.Global.shouldReport(groupConfig, fullPos, finding.Message) { + return false + } + + // Try the analyzer config. + ac, ok := c.Analyzers[finding.Category] + if !ok { + return groupConfig.Default + } + return ac.shouldReport(groupConfig, fullPos, finding.Message) } diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 543598b52..b3d297308 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -2,6 +2,28 @@ load("//tools/bazeldefs:go.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") +NogoConfigInfo = provider( + "information about a nogo configuration", + fields = { + "srcs": "the collection of configuration files", + }, +) + +def _nogo_config_impl(ctx): + return [NogoConfigInfo( + srcs = ctx.files.srcs, + )] + +nogo_config = rule( + implementation = _nogo_config_impl, + attrs = { + "srcs": attr.label_list( + doc = "a list of yaml files (schema defined by tool/nogo/config.go).", + allow_files = True, + ), + }, +) + NogoTargetInfo = provider( "information about the Go target", fields = { @@ -20,11 +42,14 @@ nogo_target = go_rule( rule, implementation = _nogo_target_impl, attrs = { - # goarch is the build architecture. This will normally be provided by a - # select statement, but this information is propagated to other rules. - "goarch": attr.string(mandatory = True), - # goos is similarly the build operating system target. - "goos": attr.string(mandatory = True), + "goarch": attr.string( + doc = "the Go build architecture (propagated to other rules).", + mandatory = True, + ), + "goos": attr.string( + doc = "the Go OS target (propagated to other rules).", + mandatory = True, + ), }, ) @@ -81,7 +106,7 @@ NogoStdlibInfo = provider( "information for nogo analysis (standard library facts)", fields = { "facts": "serialized standard library facts", - "findings": "package findings (if relevant)", + "raw_findings": "raw package findings (if relevant)", }, ) @@ -90,7 +115,7 @@ def _nogo_stdlib_impl(ctx): nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(ctx.label.name + ".facts") - findings = ctx.actions.declare_file(ctx.label.name + ".findings") + raw_findings = ctx.actions.declare_file(ctx.label.name + ".raw_findings") config = struct( Srcs = [f.path for f in go_ctx.stdlib_srcs], GOOS = go_ctx.goos, @@ -101,15 +126,15 @@ def _nogo_stdlib_impl(ctx): ctx.actions.write(config_file, config.to_json()) ctx.actions.run( inputs = [config_file] + go_ctx.stdlib_srcs, - outputs = [facts, findings], + outputs = [facts, raw_findings], tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), executable = ctx.files._nogo_check[0], - mnemonic = "GoStandardLibraryAnalysis", + mnemonic = "NogoStandardLibraryAnalysis", progress_message = "Analyzing Go Standard Library", arguments = go_ctx.nogo_args + [ "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-stdlib=%s" % config_file.path, - "-findings=%s" % findings.path, + "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, ], ) @@ -117,7 +142,7 @@ def _nogo_stdlib_impl(ctx): # Return the stdlib facts as output. return [NogoStdlibInfo( facts = facts, - findings = findings, + raw_findings = raw_findings, )] nogo_stdlib = go_rule( @@ -148,7 +173,8 @@ NogoInfo = provider( "information for nogo analysis", fields = { "facts": "serialized package facts", - "findings": "package findings (if relevant)", + "raw_findings": "raw package findings (if relevant)", + "escapes": "escape-only findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", "srcs": "srcs (for go_test support)", @@ -214,6 +240,7 @@ def _nogo_aspect_impl(target, ctx): # Collect all info from shadow dependencies. fact_map = dict() import_map = dict() + all_raw_findings = [] for dep in deps: # There will be no file attribute set for all transitive dependencies # that are not go_library or go_binary rules, such as a proto rules. @@ -231,6 +258,9 @@ def _nogo_aspect_impl(target, ctx): import_map[info.importpath] = x_files[0] fact_map[info.importpath] = info.facts.path + # Collect all findings; duplicates are resolved at the end. + all_raw_findings.extend(info.raw_findings) + # Ensure the above are available as inputs. inputs.append(info.facts) inputs += info.binaries @@ -244,7 +274,7 @@ def _nogo_aspect_impl(target, ctx): nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(target.label.name + ".facts") - findings = ctx.actions.declare_file(target.label.name + ".findings") + raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings") escapes = ctx.actions.declare_file(target.label.name + ".escapes") config = struct( ImportPath = importpath, @@ -262,39 +292,39 @@ def _nogo_aspect_impl(target, ctx): inputs.append(config_file) ctx.actions.run( inputs = inputs, - outputs = [facts, findings, escapes], + outputs = [facts, raw_findings, escapes], tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), executable = ctx.files._nogo_check[0], - mnemonic = "GoStaticAnalysis", + mnemonic = "NogoAnalysis", progress_message = "Analyzing %s" % target.label, arguments = go_ctx.nogo_args + [ "-binary=%s" % target_objfile.path, "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-package=%s" % config_file.path, - "-findings=%s" % findings.path, + "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, "-escapes=%s" % escapes.path, ], ) + # Flatten all findings from all dependencies. + # + # This is done because all the filtering must be done at the + # top-level nogo_test to dynamically apply a configuration. + # This does not actually add any additional work here, but + # will simply propagate the full list of files. + all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings] + # Return the package facts as output. - return [ - NogoInfo( - facts = facts, - findings = findings, - importpath = importpath, - binaries = binaries, - srcs = srcs, - deps = deps, - ), - OutputGroupInfo( - # Expose all findings (should just be a single file). This can be - # used for build analysis of the nogo findings. - nogo_findings = depset([findings]), - # Expose all escape analysis findings (see above). - nogo_escapes = depset([escapes]), - ), - ] + return [NogoInfo( + facts = facts, + raw_findings = all_raw_findings, + escapes = escapes, + importpath = importpath, + binaries = binaries, + srcs = srcs, + deps = deps, + )] nogo_aspect = go_rule( aspect, @@ -327,41 +357,72 @@ nogo_aspect = go_rule( def _nogo_test_impl(ctx): """Check nogo findings.""" - # Build a runner that checks the facts files. - findings = [dep[NogoInfo].findings for dep in ctx.attr.deps] - runner = ctx.actions.declare_file(ctx.label.name) + # Ensure there's a single dependency. + if len(ctx.attr.deps) != 1: + fail("nogo_test requires exactly one dep.") + raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings + escapes = ctx.attr.deps[0][NogoInfo].escapes + + # Build a step that applies the configuration. + config_srcs = ctx.attr.config[NogoConfigInfo].srcs + findings = ctx.actions.declare_file(ctx.label.name + ".findings") ctx.actions.run( - inputs = findings + ctx.files.srcs, - outputs = [runner], - tools = depset(ctx.files._gentest), - executable = ctx.files._gentest[0], - mnemonic = "Gentest", + inputs = raw_findings + ctx.files.srcs + config_srcs, + outputs = [findings], + tools = depset(ctx.files._filter), + executable = ctx.files._filter[0], + mnemonic = "GoStaticAnalysis", progress_message = "Generating %s" % ctx.label, - arguments = [runner.path] + [f.path for f in findings], + arguments = ["-input=%s" % f.path for f in raw_findings] + + ["-config=%s" % f.path for f in config_srcs] + + ["-output=%s" % findings.path], ) + + # Build a runner that checks the filtered facts. + # + # Note that this calls the filter binary without any configuration, so all + # findings will be included. But this is expected, since we've already + # filtered out everything that should not be included. + runner = ctx.actions.declare_file(ctx.label.name) + runner_content = [ + "#!/bin/bash", + "exec %s -input=%s" % (ctx.files._filter[0].short_path, findings.short_path), + "", + ] + ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) + return [DefaultInfo( + # The runner just executes the filter again, on the + # newly generated filtered findings. We still need + # the filter tool as part of our runfiles, however. + runfiles = ctx.runfiles(files = ctx.files._filter + [findings]), executable = runner, + ), OutputGroupInfo( + # Propagate the filtered filters, for consumption by + # build tooling. Note that the build tooling typically + # pays attention to the mnemoic above, so this must be + # what is expected by the tooling. + nogo_findings = depset([findings]), + # Expose all escape analysis findings (see above). + nogo_escapes = depset([escapes]), )] -_nogo_test = rule( +nogo_test = rule( implementation = _nogo_test_impl, attrs = { - # deps should have only a single element. - "deps": attr.label_list(aspects = [nogo_aspect]), - # srcs exist here only to ensure that this target is - # directly affected by changes to the source files. - "srcs": attr.label_list(allow_files = True), - "_gentest": attr.label(default = "//tools/nogo:gentest"), + "config": attr.label( + mandatory = True, + doc = "A rule of kind nogo_config.", + ), + "deps": attr.label_list( + aspects = [nogo_aspect], + doc = "Exactly one Go dependency to be analyzed.", + ), + "srcs": attr.label_list( + allow_files = True, + doc = "Relevant src files. This is ignored except to make the nogo_test directly affected by the files.", + ), + "_filter": attr.label(default = "//tools/nogo/filter:filter"), }, test = True, ) - -def nogo_test(name, srcs, library, **kwargs): - tags = kwargs.pop("tags", []) + ["nogo"] - _nogo_test( - name = name, - srcs = srcs, - deps = [library], - tags = tags, - **kwargs - ) diff --git a/tools/nogo/filter/BUILD b/tools/nogo/filter/BUILD new file mode 100644 index 000000000..e56a783e2 --- /dev/null +++ b/tools/nogo/filter/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "filter", + srcs = ["main.go"], + nogo = False, + visibility = ["//visibility:public"], + deps = [ + "//tools/nogo", + "@in_gopkg_yaml_v2//:go_default_library", + ], +) diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go new file mode 100644 index 000000000..9cf41b3b0 --- /dev/null +++ b/tools/nogo/filter/main.go @@ -0,0 +1,131 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary check is the nogo entrypoint. +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + + yaml "gopkg.in/yaml.v2" + "gvisor.dev/gvisor/tools/nogo" +) + +type stringList []string + +func (s *stringList) String() string { + return strings.Join(*s, ",") +} + +func (s *stringList) Set(value string) error { + *s = append(*s, value) + return nil +} + +var ( + inputFiles stringList + configFiles stringList + outputFile string + showConfig bool +) + +func init() { + flag.Var(&inputFiles, "input", "findings input files") + flag.StringVar(&outputFile, "output", "", "findings output file") + flag.Var(&configFiles, "config", "findings configuration files") + flag.BoolVar(&showConfig, "show-config", false, "dump configuration only") +} + +func main() { + flag.Parse() + + // Load all available findings. + var findings []nogo.Finding + for _, filename := range inputFiles { + inputFindings, err := nogo.ExtractFindingsFromFile(filename) + if err != nil { + log.Fatalf("unable to extract findings from %s: %v", filename, err) + } + findings = append(findings, inputFindings...) + } + + // Open and merge all configuations. + config := &nogo.Config{ + Global: make(nogo.AnalyzerConfig), + Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig), + } + for _, filename := range configFiles { + content, err := ioutil.ReadFile(filename) + if err != nil { + log.Fatalf("unable to read %s: %v", filename, err) + } + var newConfig nogo.Config // For current file. + if err := yaml.Unmarshal(content, &newConfig); err != nil { + log.Fatalf("unable to decode %s: %v", filename, err) + } + config.Merge(&newConfig) + if showConfig { + bytes, err := yaml.Marshal(&newConfig) + if err != nil { + log.Fatalf("error marshalling config: %v", err) + } + mergedBytes, err := yaml.Marshal(config) + if err != nil { + log.Fatalf("error marshalling config: %v", err) + } + fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes)) + fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes)) + } + } + if err := config.Compile(); err != nil { + log.Fatalf("error compiling config: %v", err) + } + if showConfig { + os.Exit(0) + } + + // Filter the findings (and aggregate by group). + filteredFindings := make([]nogo.Finding, 0, len(findings)) + for _, finding := range findings { + if ok := config.ShouldReport(finding); ok { + filteredFindings = append(filteredFindings, finding) + } + } + + // Write the output (if required). + // + // If the outputFile is specified, then we exit here. Otherwise, + // we continue to write to stdout and treat like a test. + if outputFile != "" { + if err := nogo.WriteFindingsToFile(filteredFindings, outputFile); err != nil { + log.Fatalf("unable to write findings: %v", err) + } + return + } + + // Treat the run as a test. + if len(filteredFindings) == 0 { + fmt.Fprintf(os.Stdout, "PASS\n") + os.Exit(0) + } + for _, finding := range filteredFindings { + fmt.Fprintf(os.Stdout, "%s\n", finding.String()) + } + os.Exit(1) +} diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go new file mode 100644 index 000000000..5bd850269 --- /dev/null +++ b/tools/nogo/findings.go @@ -0,0 +1,63 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "encoding/json" + "fmt" + "go/token" + "io/ioutil" +) + +// Finding is a single finding. +type Finding struct { + Category AnalyzerName + Position token.Position + Message string +} + +// String implements fmt.Stringer.String. +func (f *Finding) String() string { + return fmt.Sprintf("%s: %s: %s", f.Category, f.Position.String(), f.Message) +} + +// WriteFindingsToFile writes findings to a file. +func WriteFindingsToFile(findings []Finding, filename string) error { + content, err := WriteFindingsToBytes(findings) + if err != nil { + return err + } + return ioutil.WriteFile(filename, content, 0644) +} + +// WriteFindingsToBytes serializes findings as bytes. +func WriteFindingsToBytes(findings []Finding) ([]byte, error) { + return json.Marshal(findings) +} + +// ExtractFindingsFromFile loads findings from a file. +func ExtractFindingsFromFile(filename string) ([]Finding, error) { + content, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return ExtractFindingsFromBytes(content) +} + +// ExtractFindingsFromBytes loads findings from bytes. +func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { + err = json.Unmarshal(content, &findings) + return findings, err +} diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh deleted file mode 100755 index 0a762f9f6..000000000 --- a/tools/nogo/gentest.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euo pipefail - -if [[ "$#" -lt 2 ]]; then - echo "usage: $0 <output> <findings...>" - exit 2 -fi -declare violations=0 -declare output=$1 -shift - -# Start the script. -echo "#!/bin/sh" > "${output}" - -# Read a list of findings files. -declare filename -declare line -for filename in "$@"; do - if [[ -z "${filename}" ]]; then - continue - fi - while read -r line; do - line="${line@Q}" - violations=$((${violations}+1)); - echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}" - done < "${filename}" -done - -# Show violations. -if [[ "${violations}" -eq 0 ]]; then - echo "echo -e '\\033[0;32mPASS\\033[0;31m\\033[0m'" >> "${output}" -else - echo "exit 1" >> "${output}" -fi diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go deleted file mode 100644 index b7b73fa27..000000000 --- a/tools/nogo/matchers.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nogo - -import ( - "go/token" - "regexp" - "strings" - - "golang.org/x/tools/go/analysis" -) - -type matcher interface { - ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool -} - -// pathRegexps filters explicit paths. -type pathRegexps struct { - expr []*regexp.Regexp - - // include, if true, indicates that paths matching any regexp in expr - // match. - // - // If false, paths matching no regexps in expr match. - include bool -} - -// buildRegexps builds a list of regular expressions. -// -// This will panic on error. -func buildRegexps(prefix string, args ...string) []*regexp.Regexp { - result := make([]*regexp.Regexp, 0, len(args)) - for _, arg := range args { - result = append(result, regexp.MustCompile(prefix+arg)) - } - return result -} - -// notPath works around the lack of backtracking. -// -// It is used to construct a regular expression for non-matching components. -func notPath(name string) string { - sb := strings.Builder{} - sb.WriteString("(") - for i := range name { - if i > 0 { - sb.WriteString("|") - } - sb.WriteString(name[:i]) - sb.WriteString("[^") - sb.WriteByte(name[i]) - sb.WriteString("/][^/]*") - } - sb.WriteString(")") - return sb.String() -} - -// ShouldReport implements matcher.ShouldReport. -func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { - fullPos := fs.Position(d.Pos).String() - for _, path := range p.expr { - if path.MatchString(fullPos) { - return p.include - } - } - return !p.include -} - -// internalExcluded excludes specific internal paths. -func internalExcluded(paths ...string) *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(internalPrefix, paths...), - include: false, - } -} - -// excludedExcluded excludes specific external paths. -func externalExcluded(paths ...string) *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(externalPrefix, paths...), - include: false, - } -} - -// internalMatches returns a path matcher for internal packages. -func internalMatches() *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(internalPrefix, internalDefault), - include: true, - } -} - -// generatedExcluded excludes all generated code. -func generatedExcluded() *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(generatedPrefix, ".*"), - include: false, - } -} - -// resultExcluded excludes explicit message contents. -type resultExcluded []string - -// ShouldReport implements matcher.ShouldReport. -func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool { - for _, str := range r { - if strings.Contains(d.Message, str) { - return false - } - } - return true // Not excluded. -} - -// andMatcher is a composite matcher. -type andMatcher struct { - all []matcher -} - -// ShouldReport implements matcher.ShouldReport. -func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { - for _, m := range a.all { - if !m.ShouldReport(d, fs) { - return false - } - } - return true -} - -// and is a syntactic convension for andMatcher. -func and(ms ...matcher) *andMatcher { - return &andMatcher{ - all: ms, - } -} - -// anyMatcher matches everything. -type anyMatcher struct{} - -// ShouldReport implements matcher.ShouldReport. -func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { - return true -} - -// alwaysMatches returns an anyMatcher instance. -func alwaysMatches() anyMatcher { - return anyMatcher{} -} - -// neverMatcher will never match. -type neverMatcher struct{} - -// ShouldReport implements matcher.ShouldReport. -func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { - return false -} - -// disableMatches returns a neverMatcher instance. -func disableMatches() neverMatcher { - return neverMatcher{} -} diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index e19e3c237..779d4d6d8 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -21,7 +21,6 @@ package nogo import ( "encoding/json" "errors" - "flag" "fmt" "go/ast" "go/build" @@ -45,20 +44,20 @@ import ( "gvisor.dev/gvisor/tools/checkescape" ) -// stdlibConfig is serialized as the configuration. +// StdlibConfig is serialized as the configuration. // // This contains everything required for stdlib analysis. -type stdlibConfig struct { +type StdlibConfig struct { Srcs []string GOOS string GOARCH string Tags []string } -// packageConfig is serialized as the configuration. +// PackageConfig is serialized as the configuration. // // This contains everything required for single package analysis. -type packageConfig struct { +type PackageConfig struct { ImportPath string GoFiles []string NonGoFiles []string @@ -84,7 +83,7 @@ type saver func([]byte) error // // This is done because all stdlib data is stored together, and we don't want // to load this data many times over. -func (c *packageConfig) factLoader() (loader, error) { +func (c *PackageConfig) factLoader() (loader, error) { allFacts := make(map[string][]byte) if c.StdlibFacts != "" { data, err := ioutil.ReadFile(c.StdlibFacts) @@ -114,7 +113,7 @@ func (c *packageConfig) factLoader() (loader, error) { // shouldInclude indicates whether the file should be included. // // NOTE: This does only basic parsing of tags. -func (c *packageConfig) shouldInclude(path string) (bool, error) { +func (c *PackageConfig) shouldInclude(path string) (bool, error) { ctx := build.Default ctx.GOOS = c.GOOS ctx.GOARCH = c.GOARCH @@ -128,7 +127,7 @@ func (c *packageConfig) shouldInclude(path string) (bool, error) { // files, and the facts. Note that this importer implementation will always // pass when a given package is not available. type importer struct { - *packageConfig + *PackageConfig fset *token.FileSet cache map[string]*types.Package lastErr error @@ -185,14 +184,14 @@ func (i *importer) Import(path string) (*types.Package, error) { // ErrSkip indicates the package should be skipped. var ErrSkip = errors.New("skipped") -// checkStdlib checks the standard library. +// CheckStdlib checks the standard library. // // This constructs a synthetic package configuration for each library in the -// standard library sources, and call checkPackage repeatedly. +// standard library sources, and call CheckPackage repeatedly. // // Note that not all parts of the source are expected to build. We skip obvious // test files, and cmd files, which should not be dependencies. -func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) { +func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings []Finding, facts []byte, err error) { if len(config.Srcs) == 0 { return nil, nil, nil } @@ -225,7 +224,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } // Aggregate all files by directory. - packages := make(map[string]*packageConfig) + packages := make(map[string]*PackageConfig) for _, file := range config.Srcs { if !strings.HasPrefix(file, rootSrcPrefix) { // Superflouous file. @@ -243,7 +242,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } c, ok := packages[pkg] if !ok { - c = &packageConfig{ + c = &PackageConfig{ ImportPath: pkg, GOOS: config.GOOS, GOARCH: config.GOARCH, @@ -262,7 +261,6 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } // Closure to check a single package. - allFindings := make([]string, 0) stdlibFacts := make(map[string][]byte) stdlibErrs := make(map[string]error) var checkOne func(pkg string) error // Recursive. @@ -301,7 +299,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str }() // Run the analysis. - findings, factData, err := checkPackage(config, ac, checkOne) + findings, factData, err := CheckPackage(config, analyzers, checkOne) if err != nil { // If we can't analyze a package from the standard library, // then we skip it. It will simply not have any findings. @@ -344,7 +342,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str return allFindings, factData, nil } -// checkPackage runs all analyzers. +// CheckPackage runs all given analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. // This returns a list of matching analysis issues, or an error if the analysis @@ -352,9 +350,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str // // [1] bazelbuid/rules_go/tools/builders/nogo_main.go // [2] golang.org/x/tools/go/checker/internal/checker -func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) { +func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importCallback func(string) error) (findings []Finding, factData []byte, err error) { imp := &importer{ - packageConfig: config, + PackageConfig: config, fset: token.NewFileSet(), cache: make(map[string]*types.Package), callback: importCallback, @@ -406,7 +404,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo // Register fact types and establish dependencies between analyzers. // The visit closure will execute recursively, and populate results // will all required analysis results. - diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic) results := make(map[*analysis.Analyzer]interface{}) var visit func(*analysis.Analyzer) error // For recursion. visit = func(a *analysis.Analyzer) error { @@ -421,27 +418,25 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } } - // Prepare the matcher. - m := ac[a] - report := func(d analysis.Diagnostic) { - if m.ShouldReport(d, imp.fset) { - diagnostics[a] = append(diagnostics[a], d) - } - } - // Run the analysis. factFilter := make(map[reflect.Type]bool) for _, f := range a.FactTypes { factFilter[reflect.TypeOf(f)] = true } p := &analysis.Pass{ - Analyzer: a, - Fset: imp.fset, - Files: syntax, - Pkg: types, - TypesInfo: typesInfo, - ResultOf: results, // All results. - Report: report, + Analyzer: a, + Fset: imp.fset, + Files: syntax, + Pkg: types, + TypesInfo: typesInfo, + ResultOf: results, // All results. + Report: func(d analysis.Diagnostic) { + findings = append(findings, Finding{ + Category: AnalyzerName(a.Name), + Position: imp.fset.Position(d.Pos), + Message: d.Message, + }) + }, ImportPackageFact: facts.ImportPackageFact, ExportPackageFact: facts.ExportPackageFact, ImportObjectFact: facts.ImportObjectFact, @@ -464,7 +459,7 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } // Visit all analyzers recursively. - for a, _ := range ac { + for _, a := range analyzers { if imp.lastErr == ErrSkip { continue // No local analysis. } @@ -473,114 +468,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } } - // Convert all diagnostics to strings. - findings := make([]string, 0, len(diagnostics)) - for a, ds := range diagnostics { - for _, d := range ds { - // Include the anlyzer name for debugability and configuration. - findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message)) - } - } - // Return all findings. - factData := facts.Encode() - return findings, factData, nil -} - -var ( - packageFile = flag.String("package", "", "package configuration file (in JSON format)") - stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") - findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") - factsOutput = flag.String("facts", "", "output file for facts (optional)") - escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") -) - -func loadConfig(file string, config interface{}) interface{} { - // Load the configuration. - f, err := os.Open(file) - if err != nil { - log.Fatalf("unable to open configuration %q: %v", file, err) - } - defer f.Close() - dec := json.NewDecoder(f) - dec.DisallowUnknownFields() - if err := dec.Decode(config); err != nil { - log.Fatalf("unable to decode configuration: %v", err) - } - return config -} - -// Main is the entrypoint; it should be called directly from main. -// -// N.B. This package registers it's own flags. -func Main() { - // Parse all flags. - flag.Parse() - - var ( - findings []string - factData []byte - err error - ) - - // Check the configuration. - if *packageFile != "" && *stdlibFile != "" { - log.Fatalf("unable to perform stdlib and package analysis; provide only one!") - } else if *stdlibFile != "" { - // Perform basic analysis. - c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig) - findings, factData, err = checkStdlib(c, analyzerConfig) - } else if *packageFile != "" { - // Perform basic analysis. - c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig) - findings, factData, err = checkPackage(c, analyzerConfig, nil) - // Do we need to do escape analysis? - if *escapesOutput != "" { - f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - log.Fatalf("unable to open output %q: %v", *escapesOutput, err) - } - defer f.Close() - escapes, _, err := checkPackage(c, escapesConfig, nil) - if err != nil { - log.Fatalf("error performing escape analysis: %v", err) - } - for _, escape := range escapes { - fmt.Fprintf(f, "%s\n", escape) - } - } - } else { - log.Fatalf("please provide at least one of package or stdlib!") - } - - // Save facts. - if *factsOutput != "" { - if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { - log.Fatalf("error saving findings to %q: %v", *factsOutput, err) - } - } - - // Open the output file. - var w io.Writer = os.Stdout - if *findingsOutput != "" { - f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - log.Fatalf("unable to open output %q: %v", *findingsOutput, err) - } - defer f.Close() - w = f - } - - // Handle findings & errors. - if err != nil { - log.Fatalf("error checking package: %v", err) - } - if len(findings) == 0 { - return - } - - // Print findings. - for _, finding := range findings { - fmt.Fprintf(w, "%s\n", finding) - } + return findings, facts.Encode(), nil } diff --git a/tools/nogo/register.go b/tools/nogo/register.go deleted file mode 100644 index 34b173937..000000000 --- a/tools/nogo/register.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nogo - -import ( - "encoding/gob" - "log" - - "golang.org/x/tools/go/analysis" -) - -// analyzers returns all configured analyzers. -func analyzers() (all []*analysis.Analyzer) { - for a, _ := range analyzerConfig { - all = append(all, a) - } - for a, _ := range escapesConfig { - all = append(all, a) - } - return all -} - -func init() { - // Validate basic configuration. - if err := analysis.Validate(analyzers()); err != nil { - log.Fatalf("unable to validate analyzer: %v", err) - } - - // Register all fact types. - // - // N.B. This needs to be done recursively, because there may be - // analyzers in the Requires list that do not appear explicitly above. - registered := make(map[*analysis.Analyzer]struct{}) - var register func(*analysis.Analyzer) - register = func(a *analysis.Analyzer) { - if _, ok := registered[a]; ok { - return - } - - // Regsiter dependencies. - for _, da := range a.Requires { - register(da) - } - - // Register local facts. - for _, f := range a.FactTypes { - gob.Register(f) - } - - registered[a] = struct{}{} // Done. - } - for _, a := range analyzers() { - register(a) - } -} diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD deleted file mode 100644 index 7ab340b51..000000000 --- a/tools/nogo/util/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "util", - srcs = ["util.go"], - visibility = ["//visibility:public"], -) diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go deleted file mode 100644 index 919fec799..000000000 --- a/tools/nogo/util/util.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package util contains nogo-related utilities. -package util - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" -) - -// findingRegexp is used to parse findings. -var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`) - -const ( - categoryIndex = 1 - fullPathAndLineIndex = 2 - fullPathIndex = 3 - lineIndex = 4 - messageIndex = 7 -) - -// Finding is a single finding. -type Finding struct { - Category string - Path string - Line int - Message string -} - -// ExtractFindingsFromFile loads findings from a file. -func ExtractFindingsFromFile(filename string) ([]Finding, error) { - content, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err - } - return ExtractFindingsFromBytes(content) -} - -// ExtractFindingsFromBytes loads findings from bytes. -func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { - lines := strings.Split(string(content), "\n") - for _, singleLine := range lines { - // Skip blank lines. - singleLine = strings.TrimSpace(singleLine) - if singleLine == "" { - continue - } - m := findingRegexp.FindStringSubmatch(singleLine) - if m == nil { - // We shouldn't see findings like this. - return findings, fmt.Errorf("poorly formated line: %v", singleLine) - } - if m[fullPathAndLineIndex] == "-" { - continue // No source file available. - } - // Cleanup the message. - message := m[messageIndex] - message = strings.Replace(message, " → ", "\n → ", -1) - message = strings.Replace(message, " or ", "\n or ", -1) - // Construct a new annotation. - lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32) - findings = append(findings, Finding{ - Category: m[categoryIndex], - Path: m[fullPathIndex], - Line: int(lineNumber), - Message: message, - }) - } - return findings, nil -} diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD index 7d9c9a3fb..dab954e25 100644 --- a/tools/parsers/BUILD +++ b/tools/parsers/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") package(licenses = ["notice"]) @@ -7,6 +7,7 @@ go_test( size = "small", srcs = ["go_parser_test.go"], library = ":parsers", + nogo = False, deps = [ "//tools/bigquery", "@com_github_google_go_cmp//cmp:go_default_library", @@ -19,9 +20,22 @@ go_library( srcs = [ "go_parser.go", ], + nogo = False, visibility = ["//:sandbox"], deps = [ "//test/benchmarks/tools", "//tools/bigquery", ], ) + +go_binary( + name = "parser", + testonly = 1, + srcs = ["parser_main.go"], + nogo = False, + deps = [ + ":parsers", + "//runsc/flag", + "//tools/bigquery", + ], +) diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go index 2cf74c883..df4875e6a 100644 --- a/tools/parsers/go_parser.go +++ b/tools/parsers/go_parser.go @@ -27,20 +27,21 @@ import ( "gvisor.dev/gvisor/tools/bigquery" ) -// parseOutput expects golang benchmark output returns a Benchmark struct formatted for BigQuery. -func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*bigquery.Benchmark, error) { - var benchmarks []*bigquery.Benchmark +// ParseOutput expects golang benchmark output and returns a struct formatted +// for BigQuery. +func ParseOutput(output string, name string, official bool) (*bigquery.Suite, error) { + suite := bigquery.NewSuite(name) lines := strings.Split(output, "\n") for _, line := range lines { - bm, err := parseLine(line, metadata, official) + bm, err := parseLine(line, official) if err != nil { return nil, fmt.Errorf("failed to parse line '%s': %v", line, err) } if bm != nil { - benchmarks = append(benchmarks, bm) + suite.Benchmarks = append(suite.Benchmarks, bm) } } - return benchmarks, nil + return suite, nil } // parseLine handles parsing a benchmark line into a bigquery.Benchmark. @@ -58,9 +59,8 @@ func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]* // {Name: ns/op, Unit: ns/op, Sample: 1397875880} // {Name: requests_per_second, Unit: QPS, Sample: 140 } // } -// Metadata: metadata //} -func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigquery.Benchmark, error) { +func parseLine(line string, official bool) (*bigquery.Benchmark, error) { fields := strings.Fields(line) // Check if this line is a Benchmark line. Otherwise ignore the line. @@ -79,7 +79,6 @@ func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigque } bm := bigquery.NewBenchmark(name, iters, official) - bm.Metadata = metadata for _, p := range params { bm.AddCondition(p.Name, p.Value) } diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go index 36996b7c8..0aa1152a2 100644 --- a/tools/parsers/go_parser_test.go +++ b/tools/parsers/go_parser_test.go @@ -94,13 +94,11 @@ func TestParseLine(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := parseLine(tc.data, nil, false) + got, err := parseLine(tc.data, false) if err != nil { t.Fatalf("parseLine failed with: %v", err) } - tc.want.Timestamp = got.Timestamp - if !cmp.Equal(tc.want, got, nil) { for _, c := range got.Condition { t.Logf("Cond: %+v", c) @@ -150,14 +148,14 @@ BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 46 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - bms, err := parseOutput(tc.data, nil, false) + suite, err := ParseOutput(tc.data, "", false) if err != nil { t.Fatalf("parseOutput failed: %v", err) - } else if len(bms) != tc.numBenchmarks { - t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(bms), bms) + } else if len(suite.Benchmarks) != tc.numBenchmarks { + t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(suite.Benchmarks), suite.Benchmarks) } - for _, bm := range bms { + for _, bm := range suite.Benchmarks { if len(bm.Metric) != tc.numMetrics { t.Fatalf("NumMetrics failed want: %d got: %d %+v", tc.numMetrics, len(bm.Metric), bm.Metric) } diff --git a/tools/parsers/parser_main.go b/tools/parsers/parser_main.go new file mode 100644 index 000000000..6c6182464 --- /dev/null +++ b/tools/parsers/parser_main.go @@ -0,0 +1,129 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary parser parses Benchmark data from golang benchmarks, +// puts it into a Schema for BigQuery, and sends it to BigQuery. +// parser will also initialize a table with the Benchmarks BigQuery schema. +package main + +import ( + "context" + "fmt" + "io/ioutil" + "os" + + "gvisor.dev/gvisor/runsc/flag" + bq "gvisor.dev/gvisor/tools/bigquery" + "gvisor.dev/gvisor/tools/parsers" +) + +const ( + initString = "init" + initDescription = "initializes a new table with benchmarks schema" + parseString = "parse" + parseDescription = "parses given benchmarks file and sends it to BigQuery table." +) + +var ( + // The init command will create a new dataset/table in the given project and initialize + // the table with the schema in //tools/bigquery/bigquery.go. If the table/dataset exists + // or has been initialized, init has no effect and successfully returns. + initCmd = flag.NewFlagSet(initString, flag.ContinueOnError) + initProject = initCmd.String("project", "", "GCP project to send benchmarks.") + initDataset = initCmd.String("dataset", "", "dataset to send benchmarks data.") + initTable = initCmd.String("table", "", "table to send benchmarks data.") + + // The parse command parses benchmark data in `file` and sends it to the + // requested table. + parseCmd = flag.NewFlagSet(parseString, flag.ContinueOnError) + file = parseCmd.String("file", "", "file to parse for benchmarks") + name = parseCmd.String("suite_name", "", "name of the benchmark suite") + clNumber = parseCmd.String("cl", "", "changelist number of this run") + gitCommit = parseCmd.String("git_commit", "", "git commit sha for this run") + parseProject = parseCmd.String("project", "", "GCP project to send benchmarks.") + parseDataset = parseCmd.String("dataset", "", "dataset to send benchmarks data.") + parseTable = parseCmd.String("table", "", "table to send benchmarks data.") + official = parseCmd.Bool("official", false, "mark input data as official.") +) + +// initBenchmarks initializes a dataset/table in a BigQuery project. +func initBenchmarks(ctx context.Context) error { + return bq.InitBigQuery(ctx, *initProject, *initDataset, *initTable, nil) +} + +// parseBenchmarks parses the given file into the BigQuery schema, +// adds some custom data for the commit, and sends the data to BigQuery. +func parseBenchmarks(ctx context.Context) error { + data, err := ioutil.ReadFile(*file) + if err != nil { + return fmt.Errorf("failed to read file: %v", err) + } + suite, err := parsers.ParseOutput(string(data), *name, *official) + if err != nil { + return fmt.Errorf("failed parse data: %v", err) + } + extraConditions := []*bq.Condition{ + { + Name: "change_list", + Value: *clNumber, + }, + { + Name: "commit", + Value: *gitCommit, + }, + } + + suite.Conditions = append(suite.Conditions, extraConditions...) + return bq.SendBenchmarks(ctx, suite, *parseProject, *parseDataset, *parseTable, nil) +} + +func main() { + ctx := context.Background() + switch { + // the "init" command + case len(os.Args) >= 2 && os.Args[1] == initString: + if err := initCmd.Parse(os.Args[2:]); err != nil { + fmt.Fprintf(os.Stderr, "failed parse flags: %v", err) + os.Exit(1) + } + if err := initBenchmarks(ctx); err != nil { + failure := "failed to initialize project: %s dataset: %s table: %s: %v" + fmt.Fprintf(os.Stderr, failure, *parseProject, *parseDataset, *parseTable, err) + os.Exit(1) + } + // the "parse" command. + case len(os.Args) >= 2 && os.Args[1] == parseString: + if err := parseCmd.Parse(os.Args[2:]); err != nil { + fmt.Fprintf(os.Stderr, "failed parse flags: %v", err) + os.Exit(1) + } + if err := parseBenchmarks(ctx); err != nil { + fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v", err) + os.Exit(1) + } + default: + printUsage() + } +} + +// printUsage prints the top level usage string. +func printUsage() { + usage := `Usage: parser <command> <flags> ... + +Available commands: + %s %s + %s %s +` + fmt.Fprintf(os.Stderr, usage, initCmd.Name(), initDescription, parseCmd.Name(), parseDescription) +} diff --git a/tools/tag_release.sh b/tools/tag_release.sh index b0bab74b4..50378065e 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -43,7 +43,7 @@ fi closest_commit() { while read line; do - if [[ "$line" =~ "commit " ]]; then + if [[ "$line" =~ ^"commit " ]]; then current_commit="${line#commit }" continue elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then @@ -57,7 +57,9 @@ closest_commit() { # Is the passed identifier a sha commit? if ! git show "${target_commit}" &> /dev/null; then # Extract the commit given a piper ID. - declare -r commit="$(git log | closest_commit "${target_commit}")" + commit="$(set +o pipefail; \ + git log --first-parent | closest_commit "${target_commit}")" + declare -r commit else declare -r commit="${target_commit}" fi diff --git a/website/BUILD b/website/BUILD index f3642b903..676c2b701 100644 --- a/website/BUILD +++ b/website/BUILD @@ -6,11 +6,16 @@ package(licenses = ["notice"]) docker_image( name = "website", - data = [":files"], + data = ":files", statements = [ "EXPOSE 8080/tcp", 'ENTRYPOINT ["/server"]', ], + tags = [ + "local", + "manual", + "nosandbox", + ], ) # files is the full file system of the generated container. diff --git a/website/_config.yml b/website/_config.yml index 20fbb3d2d..51cb8e13c 100644 --- a/website/_config.yml +++ b/website/_config.yml @@ -37,3 +37,9 @@ authors: fvoznika: name: Fabricio Voznika email: fvoznika@google.com + ianlewis: + name: Ian Lewis + email: ianlewis@google.com + mpratt: + name: Michael Pratt + email: mpratt@google.com diff --git a/website/blog/2020-10-22-platform-portability.md b/website/blog/2020-10-22-platform-portability.md new file mode 100644 index 000000000..4d82940f9 --- /dev/null +++ b/website/blog/2020-10-22-platform-portability.md @@ -0,0 +1,120 @@ +# Platform Portability + +Hardware virtualization is often seen as a requirement to provide an additional +isolation layer for untrusted applications. However, hardware virtualization +requires expensive bare-metal machines or cloud instances to run safely with +good performance, increasing cost and complexity for Cloud users. gVisor, +however, takes a more flexible approach. + +One of the pillars of gVisor's architecture is portability, allowing it to run +anywhere that runs Linux. Modern Cloud-Native applications run in containers in +many different places, from bare metal to virtual machines, and can't always +rely on nested virtualization. It is important for gVisor to be able to support +the environments where you run containers. + +gVisor achieves portability through an abstraction called a _Platform_. +Platforms can have many implementations, and each implementation can cover +different environments, making use of available software or hardware features. + +## Background + +Before we can understand how gVisor achieves portability using platforms, we +should take a step back and understand how applications interact with their +host. + +Container sandboxes can provide an isolation layer between the host and +application by virtualizing one of the layers below it, including the hardware +or operating system. Many sandboxes virtualize the hardware layer by running +applications in virtual machines. gVisor takes a different approach by +virtualizing the OS layer. + +When an application is run in a normal situation the host operating system loads +the application into user memory and schedules it for execution. The operating +system scheduler eventually schedules the application to a CPU and begins +executing it. It then handles the application's requests, such as for memory and +the lifecycle of the application. gVisor virtualizes these interactions, such as +system calls, and context switching that happen between an application and OS. + +[System calls](https://en.wikipedia.org/wiki/System_call) allow applications to +ask the OS to perform some task for it. System calls look like a normal function +call in most programming languages though works a bit differently under the +hood. When an application system call is encountered some special processing +takes place to do a +[context switch](https://en.wikipedia.org/wiki/Context_switch) into kernel mode +and begin executing code in the kernel before returning a result to the +application. Context switching may happen in other situations as well. For +example, to respond to an interrupt. + +## The Platform Interface + +gVisor provides a sandbox which implements the Linux OS interface, intercepting +OS interactions such as system calls and implements them in the sandbox kernel. + +It does this to limit interactions with the host, and protect the host from an +untrusted application running in the sandbox. The Platform is the bottom layer +of gVisor which provides the environment necessary for gVisor to control and +manage applications. In general, the Platform must: + +1. Provide the ability to create and manage memory address spaces. +2. Provide execution contexts for running applications in those memory address + spaces. +3. Provide the ability to change execution context and return control to gVisor + at specific times (e.g. system call, page fault) + +This interface is conceptually simple, but very powerful. Since the Platform +interface only requires these three capabilities, it gives gVisor enough control +for it to act as the application's OS, while still allowing the use of very +different isolation technologies under the hood. You can learn more about the +Platform interface in the +[Platform Guide](https://gvisor.dev/docs/architecture_guide/platforms/). + +## Implementations of the Platform Interface + +While gVisor can make use of technologies like hardware virtualization, it +doesn't necessarily rely on any one technology to provide a similar level of +isolation. The flexibility of the Platform interface allows for implementations +that use technologies other than hardware virtualization. This allows gVisor to +run in VMs without nested virtualization, for example. By providing an +abstraction for the underlying platform, each implementation can make various +tradeoffs regarding performance or hardware requirements. + +Currently gVisor provides two gVisor Platform implementations; the Ptrace +Platform, and the KVM Platform, each using very different methods to implement +the Platform interface. + +![gVisor Platforms](../../../../../docs/architecture_guide/platforms/platforms.png "Platforms") + +The Ptrace Platform uses +[PTRACE\_SYSEMU](http://man7.org/linux/man-pages/man2/ptrace.2.html) to trap +syscalls, and uses the host for memory mapping and context switching. This +platform can run anywhere that ptrace is available, which includes most Linux +systems, VMs or otherwise. + +The KVM Platform uses virtualization, but in an unconventional way. gVisor runs +in a virtual machine but as both guest OS and VMM, and presents no virtualized +hardware layer. This provides a simpler interface that can avoid hardware +initialization for fast start up, while taking advantage of hardware +virtualization support to improve memory isolation and performance of context +switching. + +The flexibility of the Platform interface allows for a lot of room to improve +the existing KVM and ptrace platforms, as well as the ability to utilize new +methods for improving gVisor's performance or portability in future Platform +implementations. + +## Portability + +Through the Platform interface, gVisor is able to support bare metal, virtual +machines, and Cloud environments while still providing a highly secure sandbox +for running untrusted applications. This is especially important for Cloud and +Kubernetes users because it allows gVisor to run anywhere that Kubernetes can +run and provide similar experiences in multi-region, hybrid, multi-platform +environments. + +Give gVisor's open source platforms a try. Using a Platform is as easy as +providing the `--platform` flag to `runsc`. See the documentation on +[changing platforms](https://gvisor.dev/docs/user_guide/platforms/) for how to +use different platforms with Docker. We would love to hear about your experience +so come chat with us in our +[Gitter channel](https://gitter.im/gvisor/community), or send us an +[issue on Github](https://gvisor.dev/issue) if you run into any problems. diff --git a/website/blog/BUILD b/website/blog/BUILD index 865e403da..17beb721f 100644 --- a/website/blog/BUILD +++ b/website/blog/BUILD @@ -38,6 +38,17 @@ doc( permalink = "/blog/2020/09/18/containing-a-real-vulnerability/", ) +doc( + name = "platform_portability", + src = "2020-10-22-platform-portability.md", + authors = [ + "ianlewis", + "mpratt", + ], + layout = "post", + permalink = "/blog/2020/10/22/platform-portability/", +) + docs( name = "posts", deps = [ diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go index c401b6abd..ac09550a9 100644 --- a/website/cmd/server/main.go +++ b/website/cmd/server/main.go @@ -29,6 +29,7 @@ var redirects = map[string]string{ // GitHub redirects. "/change": "https://github.com/google/gvisor", "/issue": "https://github.com/google/gvisor/issues", + "/issues": "https://github.com/google/gvisor/issues", "/issue/new": "https://github.com/google/gvisor/issues/new", "/pr": "https://github.com/google/gvisor/pulls", @@ -44,14 +45,16 @@ var redirects = map[string]string{ "/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/", // Redirect for old URLs. - "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/", - "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/", - "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/", - "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/", - "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/", - "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/", - "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/", - "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/", + "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/", + "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/", + "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/", + "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/", + "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/", + "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/", + "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/", + "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/", + "/blog/2020/09/22/platform-portability": "/blog/2020/10/22/platform-portability/", + "/blog/2020/09/22/platform-portability/": "/blog/2020/10/22/platform-portability/", // Deprecated, but links continue to work. "/cl": "https://gvisor-review.googlesource.com", @@ -60,6 +63,7 @@ var redirects = map[string]string{ var prefixHelpers = map[string]string{ "change": "https://github.com/google/gvisor/commit/%s", "issue": "https://github.com/google/gvisor/issues/%s", + "issues": "https://github.com/google/gvisor/issues/%s", "pr": "https://github.com/google/gvisor/pull/%s", // Redirects to compatibility docs. |