summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml19
-rw-r--r--CONTRIBUTING.md2
-rw-r--r--Dockerfile11
-rw-r--r--Makefile5
-rw-r--r--WORKSPACE45
-rw-r--r--benchmarks/BUILD18
-rw-r--r--benchmarks/README.md42
-rw-r--r--benchmarks/defs.bzl14
-rw-r--r--benchmarks/harness/BUILD184
-rw-r--r--benchmarks/harness/__init__.py36
-rw-r--r--benchmarks/harness/machine.py46
-rw-r--r--benchmarks/harness/machine_producers/BUILD6
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py114
-rw-r--r--benchmarks/harness/ssh_connection.py32
-rw-r--r--benchmarks/runner/BUILD17
-rw-r--r--benchmarks/runner/__init__.py75
-rw-r--r--benchmarks/runner/commands.py70
-rw-r--r--benchmarks/tcp/tcp_proxy.go2
-rw-r--r--benchmarks/workloads/ab/BUILD13
-rw-r--r--benchmarks/workloads/absl/BUILD13
-rw-r--r--benchmarks/workloads/fio/BUILD13
-rw-r--r--benchmarks/workloads/iperf/BUILD13
-rw-r--r--benchmarks/workloads/redisbenchmark/BUILD13
-rw-r--r--benchmarks/workloads/ruby/Gemfile.lock18
-rw-r--r--benchmarks/workloads/ruby_template/BUILD1
-rw-r--r--benchmarks/workloads/ruby_template/Gemfile.lock6
-rw-r--r--benchmarks/workloads/sysbench/BUILD13
-rw-r--r--benchmarks/workloads/syscall/BUILD13
-rw-r--r--kokoro/benchmark_tests.cfg26
-rw-r--r--kokoro/kythe/generate_xrefs.sh2
-rw-r--r--kokoro/packetdrill_tests.cfg9
-rw-r--r--kokoro/runtime_tests/go1.12.cfg16
-rw-r--r--kokoro/runtime_tests/java11.cfg16
-rw-r--r--kokoro/runtime_tests/nodejs12.4.0.cfg16
-rw-r--r--kokoro/runtime_tests/php7.3.6.cfg16
-rw-r--r--kokoro/runtime_tests/python3.7.3.cfg16
-rwxr-xr-xkokoro/runtime_tests/runtime_tests.sh (renamed from scripts/runtime_tests.sh)6
-rw-r--r--pkg/abi/linux/BUILD4
-rw-r--r--pkg/abi/linux/dev.go3
-rw-r--r--pkg/abi/linux/epoll.go12
-rw-r--r--pkg/abi/linux/epoll_amd64.go (renamed from pkg/usermem/usermem_unsafe.go)22
-rw-r--r--pkg/abi/linux/epoll_arm64.go26
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/file_amd64.go4
-rw-r--r--pkg/abi/linux/file_arm64.go4
-rw-r--r--pkg/abi/linux/fs.go2
-rw-r--r--pkg/abi/linux/ioctl.go26
-rw-r--r--pkg/abi/linux/ioctl_tun.go29
-rw-r--r--pkg/abi/linux/netfilter.go109
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/abi/linux/time.go8
-rw-r--r--pkg/abi/linux/xattr.go1
-rw-r--r--pkg/atomicbitops/BUILD10
-rw-r--r--pkg/atomicbitops/atomicbitops.go (renamed from pkg/atomicbitops/atomic_bitops.go)31
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s (renamed from pkg/atomicbitops/atomic_bitops_amd64.s)38
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s (renamed from pkg/atomicbitops/atomic_bitops_arm64.s)34
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go (renamed from pkg/atomicbitops/atomic_bitops_common.go)42
-rw-r--r--pkg/atomicbitops/atomicbitops_test.go (renamed from pkg/atomicbitops/atomic_bitops_test.go)64
-rw-r--r--pkg/binary/binary.go10
-rw-r--r--pkg/cpuid/BUILD9
-rw-r--r--pkg/cpuid/cpuid.go1098
-rw-r--r--pkg/cpuid/cpuid_arm64.go482
-rw-r--r--pkg/cpuid/cpuid_arm64_test.go55
-rw-r--r--pkg/cpuid/cpuid_parse_x86_test.go (renamed from pkg/cpuid/cpuid_parse_test.go)2
-rw-r--r--pkg/cpuid/cpuid_x86.go1107
-rw-r--r--pkg/cpuid/cpuid_x86_test.go (renamed from pkg/cpuid/cpuid_test.go)2
-rw-r--r--pkg/fspath/BUILD4
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go3
-rw-r--r--pkg/gohacks/BUILD11
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go2
-rw-r--r--pkg/log/log.go60
-rw-r--r--pkg/metric/metric.go1
-rw-r--r--pkg/p9/BUILD3
-rw-r--r--pkg/p9/buffer.go10
-rw-r--r--pkg/p9/client.go9
-rw-r--r--pkg/p9/client_file.go33
-rw-r--r--pkg/p9/file.go16
-rw-r--r--pkg/p9/handlers.go33
-rw-r--r--pkg/p9/messages.go787
-rw-r--r--pkg/p9/messages_test.go4
-rw-r--r--pkg/p9/p9.go72
-rw-r--r--pkg/p9/transport.go4
-rw-r--r--pkg/p9/transport_flipcall.go4
-rw-r--r--pkg/p9/transport_test.go8
-rw-r--r--pkg/p9/version.go8
-rw-r--r--pkg/pool/BUILD25
-rw-r--r--pkg/pool/pool.go (renamed from pkg/p9/pool.go)26
-rw-r--r--pkg/pool/pool_test.go (renamed from pkg/p9/pool_test.go)8
-rw-r--r--pkg/safecopy/safecopy_test.go88
-rw-r--r--pkg/safecopy/safecopy_unsafe.go98
-rw-r--r--pkg/safemem/seq_unsafe.go17
-rw-r--r--pkg/seccomp/BUILD2
-rw-r--r--pkg/seccomp/seccomp.go20
-rw-r--r--pkg/seccomp/seccomp_rules.go6
-rw-r--r--pkg/seccomp/seccomp_test.go27
-rw-r--r--pkg/sentry/BUILD8
-rw-r--r--pkg/sentry/arch/BUILD1
-rw-r--r--pkg/sentry/arch/arch_aarch64.go3
-rw-r--r--pkg/sentry/arch/arch_amd64.s7
-rw-r--r--pkg/sentry/arch/arch_state_x86.go4
-rw-r--r--pkg/sentry/arch/arch_x86.go19
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go43
-rw-r--r--pkg/sentry/arch/signal_amd64.go2
-rw-r--r--pkg/sentry/arch/signal_arm64.go2
-rw-r--r--pkg/sentry/control/BUILD5
-rw-r--r--pkg/sentry/control/proc.go127
-rw-r--r--pkg/sentry/devices/memdev/BUILD28
-rw-r--r--pkg/sentry/devices/memdev/full.go75
-rw-r--r--pkg/sentry/devices/memdev/memdev.go59
-rw-r--r--pkg/sentry/devices/memdev/null.go76
-rw-r--r--pkg/sentry/devices/memdev/random.go92
-rw-r--r--pkg/sentry/devices/memdev/zero.go88
-rw-r--r--pkg/sentry/fs/copy_up.go2
-rw-r--r--pkg/sentry/fs/dev/BUILD5
-rw-r--r--pkg/sentry/fs/dev/dev.go10
-rw-r--r--pkg/sentry/fs/dev/net_tun.go170
-rw-r--r--pkg/sentry/fs/dirent.go25
-rw-r--r--pkg/sentry/fs/file_overlay_test.go1
-rw-r--r--pkg/sentry/fs/fsutil/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go13
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go17
-rw-r--r--pkg/sentry/fs/fsutil/inode.go20
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go2
-rw-r--r--pkg/sentry/fs/g3doc/inotify.md16
-rw-r--r--pkg/sentry/fs/gofer/BUILD2
-rw-r--r--pkg/sentry/fs/gofer/attr.go6
-rw-r--r--pkg/sentry/fs/gofer/cache_policy.go3
-rw-r--r--pkg/sentry/fs/gofer/context_file.go14
-rw-r--r--pkg/sentry/fs/gofer/fifo.go40
-rw-r--r--pkg/sentry/fs/gofer/gofer_test.go2
-rw-r--r--pkg/sentry/fs/gofer/inode.go13
-rw-r--r--pkg/sentry/fs/gofer/path.go183
-rw-r--r--pkg/sentry/fs/gofer/session.go183
-rw-r--r--pkg/sentry/fs/gofer/session_state.go12
-rw-r--r--pkg/sentry/fs/gofer/socket.go8
-rw-r--r--pkg/sentry/fs/inode.go14
-rw-r--r--pkg/sentry/fs/inode_operations.go13
-rw-r--r--pkg/sentry/fs/inode_overlay.go18
-rw-r--r--pkg/sentry/fs/mount_test.go11
-rw-r--r--pkg/sentry/fs/mounts.go10
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/README.md4
-rw-r--r--pkg/sentry/fs/proc/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go5
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fs/proc/task.go17
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go10
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go9
-rw-r--r--pkg/sentry/fs/tty/slave.go2
-rw-r--r--pkg/sentry/fsbridge/BUILD24
-rw-r--r--pkg/sentry/fsbridge/bridge.go54
-rw-r--r--pkg/sentry/fsbridge/fs.go181
-rw-r--r--pkg/sentry/fsbridge/vfs.go138
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go6
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go7
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go8
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go9
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go12
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD55
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go194
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go1090
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go1150
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go135
-rw-r--r--pkg/sentry/fsimpl/gofer/handle_unsafe.go66
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go219
-rw-r--r--pkg/sentry/fsimpl/gofer/pagemath.go31
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go865
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go159
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go47
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go75
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go18
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go28
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go6
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go13
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD1
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go18
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go8
-rw-r--r--pkg/sentry/fsimpl/proc/task.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go20
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go17
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD1
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go7
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go7
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD2
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go24
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go16
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go14
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go18
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go24
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go256
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go64
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go62
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/inet.go4
-rw-r--r--pkg/sentry/inet/namespace.go99
-rw-r--r--pkg/sentry/inet/test_stack.go6
-rw-r--r--pkg/sentry/kernel/BUILD4
-rw-r--r--pkg/sentry/kernel/fd_table.go49
-rw-r--r--pkg/sentry/kernel/fs_context.go120
-rw-r--r--pkg/sentry/kernel/kernel.go191
-rw-r--r--pkg/sentry/kernel/kernel_opts.go20
-rw-r--r--pkg/sentry/kernel/rseq.go16
-rw-r--r--pkg/sentry/kernel/task.go54
-rw-r--r--pkg/sentry/kernel/task_clone.go27
-rw-r--r--pkg/sentry/kernel/task_context.go4
-rw-r--r--pkg/sentry/kernel/task_exec.go2
-rw-r--r--pkg/sentry/kernel/task_exit.go7
-rw-r--r--pkg/sentry/kernel/task_log.go21
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_start.go55
-rw-r--r--pkg/sentry/kernel/task_usermem.go2
-rw-r--r--pkg/sentry/kernel/thread_group.go6
-rw-r--r--pkg/sentry/loader/BUILD2
-rw-r--r--pkg/sentry/loader/elf.go28
-rw-r--r--pkg/sentry/loader/interpreter.go6
-rw-r--r--pkg/sentry/loader/loader.go179
-rw-r--r--pkg/sentry/loader/vdso.go7
-rw-r--r--pkg/sentry/mm/BUILD2
-rw-r--r--pkg/sentry/mm/README.md8
-rw-r--r--pkg/sentry/mm/address_space.go46
-rw-r--r--pkg/sentry/mm/lifecycle.go37
-rw-r--r--pkg/sentry/mm/metadata.go10
-rw-r--r--pkg/sentry/mm/mm.go9
-rw-r--r--pkg/sentry/mm/mm_test.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go6
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go6
-rw-r--r--pkg/sentry/platform/kvm/machine.go18
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go24
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go6
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s85
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go8
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD2
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go71
-rw-r--r--pkg/sentry/socket/hostinet/BUILD1
-rw-r--r--pkg/sentry/socket/hostinet/socket.go39
-rw-r--r--pkg/sentry/socket/hostinet/sockopt_impl.go27
-rw-r--r--pkg/sentry/socket/hostinet/stack.go5
-rw-r--r--pkg/sentry/socket/netfilter/BUILD5
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go95
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go422
-rw-r--r--pkg/sentry/socket/netfilter/targets.go35
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go143
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go142
-rw-r--r--pkg/sentry/socket/netlink/BUILD14
-rw-r--r--pkg/sentry/socket/netlink/message.go134
-rw-r--r--pkg/sentry/socket/netlink/message_test.go312
-rw-r--r--pkg/sentry/socket/netlink/provider.go2
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD2
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go238
-rw-r--r--pkg/sentry/socket/netlink/socket.go54
-rw-r--r--pkg/sentry/socket/netlink/uevent/protocol.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go133
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go55
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go10
-rw-r--r--pkg/sentry/strace/linux64_arm64.go4
-rw-r--r--pkg/sentry/strace/socket.go231
-rw-r--r--pkg/sentry/strace/strace.go58
-rw-r--r--pkg/sentry/strace/syscalls.go32
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go12
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go44
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go3
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_amd64.go72
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_arm64.go74
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go17
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go248
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD30
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go225
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go44
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go147
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go326
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go149
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go35
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go140
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go584
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go511
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go380
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go323
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_amd64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_arm64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go87
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_read.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go353
-rw-r--r--pkg/sentry/usage/memory.go4
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/context.go7
-rw-r--r--pkg/sentry/vfs/dentry.go4
-rw-r--r--pkg/sentry/vfs/device.go3
-rw-r--r--pkg/sentry/vfs/epoll.go22
-rw-r--r--pkg/sentry/vfs/file_description.go31
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go21
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go10
-rw-r--r--pkg/sentry/vfs/filesystem.go2
-rw-r--r--pkg/sentry/vfs/filesystem_type.go1
-rw-r--r--pkg/sentry/vfs/lock/BUILD13
-rw-r--r--pkg/sentry/vfs/lock/lock.go72
-rw-r--r--pkg/sentry/vfs/mount.go34
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go20
-rw-r--r--pkg/sentry/vfs/options.go7
-rw-r--r--pkg/sentry/vfs/permissions.go19
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go65
-rw-r--r--pkg/sleep/commit_noasm.go13
-rw-r--r--pkg/sleep/sleep_unsafe.go23
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go24
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go111
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go83
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/header/eth.go41
-rw-r--r--pkg/tcpip/header/eth_test.go34
-rw-r--r--pkg/tcpip/header/ipv6.go22
-rw-r--r--pkg/tcpip/header/ipv6_test.go127
-rw-r--r--pkg/tcpip/iptables/iptables.go128
-rw-r--r--pkg/tcpip/iptables/targets.go43
-rw-r--r--pkg/tcpip/iptables/types.go54
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go209
-rw-r--r--pkg/tcpip/link/tun/BUILD18
-rw-r--r--pkg/tcpip/link/tun/device.go352
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/network/arp/arp.go39
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go16
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go70
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go198
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go278
-rw-r--r--pkg/tcpip/stack/BUILD6
-rw-r--r--pkg/tcpip/stack/ndp.go107
-rw-r--r--pkg/tcpip/stack/ndp_test.go1218
-rw-r--r--pkg/tcpip/stack/nic.go252
-rw-r--r--pkg/tcpip/stack/nic_test.go62
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/route.go4
-rw-r--r--pkg/tcpip/stack/stack.go121
-rw-r--r--pkg/tcpip/stack/stack_test.go991
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
-rw-r--r--pkg/tcpip/tcpip.go66
-rw-r--r--pkg/tcpip/time_unsafe.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go16
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go26
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go38
-rw-r--r--pkg/tcpip/transport/tcp/connect.go61
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go31
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go76
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go195
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go67
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go120
-rw-r--r--pkg/usermem/BUILD2
-rw-r--r--pkg/usermem/usermem.go9
-rw-r--r--runsc/BUILD5
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/controller.go11
-rw-r--r--runsc/boot/filter/BUILD1
-rw-r--r--runsc/boot/filter/config.go27
-rw-r--r--runsc/boot/filter/config_profile.go34
-rw-r--r--runsc/boot/loader.go140
-rw-r--r--runsc/boot/network.go27
-rw-r--r--runsc/boot/pprof/BUILD11
-rw-r--r--runsc/boot/pprof/pprof.go (renamed from runsc/boot/pprof.go)6
-rw-r--r--runsc/cmd/BUILD4
-rw-r--r--runsc/cmd/boot.go2
-rw-r--r--runsc/cmd/checkpoint.go2
-rw-r--r--runsc/cmd/create.go2
-rw-r--r--runsc/cmd/debug.go2
-rw-r--r--runsc/cmd/delete.go2
-rw-r--r--runsc/cmd/do.go2
-rw-r--r--runsc/cmd/events.go2
-rw-r--r--runsc/cmd/exec.go2
-rw-r--r--runsc/cmd/gofer.go2
-rw-r--r--runsc/cmd/help.go2
-rw-r--r--runsc/cmd/install.go2
-rw-r--r--runsc/cmd/kill.go2
-rw-r--r--runsc/cmd/list.go2
-rw-r--r--runsc/cmd/pause.go2
-rw-r--r--runsc/cmd/ps.go2
-rw-r--r--runsc/cmd/restore.go2
-rw-r--r--runsc/cmd/resume.go2
-rw-r--r--runsc/cmd/run.go2
-rw-r--r--runsc/cmd/spec.go2
-rw-r--r--runsc/cmd/start.go2
-rw-r--r--runsc/cmd/state.go2
-rw-r--r--runsc/cmd/statefile.go143
-rw-r--r--runsc/cmd/syscalls.go2
-rw-r--r--runsc/cmd/wait.go2
-rw-r--r--runsc/container/BUILD2
-rw-r--r--runsc/container/console_test.go5
-rw-r--r--runsc/container/container_test.go202
-rw-r--r--runsc/container/test_app/BUILD1
-rw-r--r--runsc/container/test_app/fds.go2
-rw-r--r--runsc/container/test_app/test_app.go2
-rw-r--r--runsc/dockerutil/dockerutil.go11
-rw-r--r--runsc/flag/BUILD9
-rw-r--r--runsc/flag/flag.go (renamed from pkg/fspath/builder_unsafe.go)24
-rw-r--r--runsc/fsgofer/filter/config.go12
-rw-r--r--runsc/fsgofer/fsgofer.go16
-rw-r--r--runsc/main.go6
-rw-r--r--runsc/sandbox/network.go55
-rw-r--r--runsc/testutil/BUILD5
-rw-r--r--runsc/testutil/testutil.go127
-rw-r--r--runsc/testutil/testutil_runfiles.go75
-rw-r--r--scripts/benchmark.sh25
-rwxr-xr-xscripts/common.sh12
-rwxr-xr-xscripts/common_build.sh35
-rwxr-xr-xscripts/iptables_tests.sh4
-rwxr-xr-xscripts/packetdrill_tests.sh20
-rwxr-xr-xscripts/release.sh13
-rw-r--r--test/image/image_test.go8
-rw-r--r--test/iptables/README.md14
-rw-r--r--test/iptables/filter_input.go347
-rw-r--r--test/iptables/iptables_test.go62
-rw-r--r--test/iptables/iptables_util.go21
-rw-r--r--test/iptables/nat.go4
-rw-r--r--test/packetdrill/BUILD8
-rw-r--r--test/packetdrill/Dockerfile9
-rw-r--r--test/packetdrill/defs.bzl87
-rw-r--r--test/packetdrill/fin_wait2_timeout.pkt23
-rwxr-xr-xtest/packetdrill/packetdrill_setup.sh26
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh225
-rw-r--r--test/perf/BUILD114
-rw-r--r--test/perf/linux/BUILD356
-rw-r--r--test/perf/linux/clock_getres_benchmark.cc39
-rw-r--r--test/perf/linux/clock_gettime_benchmark.cc60
-rw-r--r--test/perf/linux/death_benchmark.cc36
-rw-r--r--test/perf/linux/epoll_benchmark.cc99
-rw-r--r--test/perf/linux/fork_benchmark.cc350
-rw-r--r--test/perf/linux/futex_benchmark.cc248
-rw-r--r--test/perf/linux/getdents_benchmark.cc149
-rw-r--r--test/perf/linux/getpid_benchmark.cc37
-rw-r--r--test/perf/linux/gettid_benchmark.cc38
-rw-r--r--test/perf/linux/mapping_benchmark.cc163
-rw-r--r--test/perf/linux/open_benchmark.cc56
-rw-r--r--test/perf/linux/pipe_benchmark.cc66
-rw-r--r--test/perf/linux/randread_benchmark.cc100
-rw-r--r--test/perf/linux/read_benchmark.cc53
-rw-r--r--test/perf/linux/sched_yield_benchmark.cc37
-rw-r--r--test/perf/linux/send_recv_benchmark.cc372
-rw-r--r--test/perf/linux/seqwrite_benchmark.cc66
-rw-r--r--test/perf/linux/signal_benchmark.cc59
-rw-r--r--test/perf/linux/sleep_benchmark.cc60
-rw-r--r--test/perf/linux/stat_benchmark.cc62
-rw-r--r--test/perf/linux/unlink_benchmark.cc66
-rw-r--r--test/perf/linux/write_benchmark.cc52
-rw-r--r--test/root/testdata/BUILD2
-rw-r--r--test/runner/BUILD22
-rw-r--r--test/runner/defs.bzl198
-rw-r--r--test/runner/gtest/BUILD (renamed from test/syscalls/gtest/BUILD)0
-rw-r--r--test/runner/gtest/gtest.go167
-rw-r--r--test/runner/runner.go (renamed from test/syscalls/syscall_test_runner.go)40
-rw-r--r--test/runtimes/README.md31
-rw-r--r--test/runtimes/runner.go9
-rw-r--r--test/syscalls/BUILD35
-rw-r--r--test/syscalls/build_defs.bzl153
-rw-r--r--test/syscalls/gtest/gtest.go93
-rw-r--r--test/syscalls/linux/32bit.cc10
-rw-r--r--test/syscalls/linux/BUILD793
-rw-r--r--test/syscalls/linux/alarm.cc3
-rw-r--r--test/syscalls/linux/bad.cc12
-rw-r--r--test/syscalls/linux/chroot.cc4
-rw-r--r--test/syscalls/linux/concurrency.cc3
-rw-r--r--test/syscalls/linux/dev.cc7
-rw-r--r--test/syscalls/linux/exec.cc3
-rw-r--r--test/syscalls/linux/exec_proc_exe_workload.cc6
-rw-r--r--test/syscalls/linux/fallocate.cc12
-rw-r--r--test/syscalls/linux/fcntl.cc7
-rw-r--r--test/syscalls/linux/fork.cc8
-rw-r--r--test/syscalls/linux/getdents.cc11
-rw-r--r--test/syscalls/linux/inotify.cc9
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc27
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h31
-rw-r--r--test/syscalls/linux/itimer.cc3
-rw-r--r--test/syscalls/linux/mmap.cc8
-rw-r--r--test/syscalls/linux/network_namespace.cc121
-rw-r--r--test/syscalls/linux/open_create.cc1
-rw-r--r--test/syscalls/linux/packet_socket.cc124
-rw-r--r--test/syscalls/linux/pipe.cc6
-rw-r--r--test/syscalls/linux/prctl.cc2
-rw-r--r--test/syscalls/linux/prctl_setuid.cc2
-rw-r--r--test/syscalls/linux/preadv.cc1
-rw-r--r--test/syscalls/linux/preadv2.cc2
-rw-r--r--test/syscalls/linux/proc.cc62
-rw-r--r--test/syscalls/linux/ptrace.cc2
-rw-r--r--test/syscalls/linux/pwritev2.cc2
-rw-r--r--test/syscalls/linux/readv.cc4
-rw-r--r--test/syscalls/linux/rseq.cc2
-rw-r--r--test/syscalls/linux/rseq/uapi.h29
-rw-r--r--test/syscalls/linux/rtsignal.cc3
-rw-r--r--test/syscalls/linux/seccomp.cc7
-rw-r--r--test/syscalls/linux/select.cc2
-rw-r--r--test/syscalls/linux/shm.cc2
-rw-r--r--test/syscalls/linux/sigiret.cc3
-rw-r--r--test/syscalls/linux/signalfd.cc2
-rw-r--r--test/syscalls/linux/sigprocmask.cc2
-rw-r--r--test/syscalls/linux/sigstop.cc2
-rw-r--r--test/syscalls/linux/sigtimedwait.cc3
-rw-r--r--test/syscalls/linux/socket_abstract.cc2
-rw-r--r--test/syscalls/linux/socket_filesystem.cc2
-rw-r--r--test/syscalls/linux/socket_generic.cc96
-rw-r--r--test/syscalls/linux/socket_generic_stress.cc83
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc171
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc33
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic_loopback.cc2
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc2
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc177
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_blocking.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc3
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc105
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc20
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc3
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc296
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc163
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h55
-rw-r--r--test/syscalls/linux/socket_netlink_util.cc45
-rw-r--r--test/syscalls/linux/socket_netlink_util.h9
-rw-r--r--test/syscalls/linux/socket_test_util.cc119
-rw-r--r--test/syscalls/linux/socket_test_util.h11
-rw-r--r--test/syscalls/linux/socket_unix_abstract_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_blocking_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_dgram_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_domain.cc2
-rw-r--r--test/syscalls/linux/socket_unix_filesystem_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc4
-rw-r--r--test/syscalls/linux/socket_unix_non_stream_blocking_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_pair.cc2
-rw-r--r--test/syscalls/linux/socket_unix_pair_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_stream_blocking_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_stream_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_stream_nonblock_local.cc2
-rw-r--r--test/syscalls/linux/splice.cc56
-rw-r--r--test/syscalls/linux/stat.cc2
-rw-r--r--test/syscalls/linux/symlink.cc2
-rw-r--r--test/syscalls/linux/tcp_socket.cc65
-rw-r--r--test/syscalls/linux/time.cc1
-rw-r--r--test/syscalls/linux/timers.cc20
-rw-r--r--test/syscalls/linux/tkill.cc2
-rw-r--r--test/syscalls/linux/tuntap.cc346
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc9
-rw-r--r--test/syscalls/linux/vfork.cc2
-rw-r--r--test/syscalls/linux/xattr.cc174
-rwxr-xr-xtest/syscalls/syscall_test_runner.sh34
-rw-r--r--test/util/BUILD31
-rw-r--r--test/util/fs_util.h11
-rw-r--r--test/util/platform_util.cc5
-rw-r--r--test/util/signal_util.h14
-rw-r--r--test/util/temp_path.cc1
-rw-r--r--test/util/test_main.cc2
-rw-r--r--test/util/test_util.cc2
-rw-r--r--test/util/test_util.h1
-rw-r--r--test/util/test_util_impl.cc14
-rw-r--r--tools/bazeldefs/BUILD (renamed from tools/build/BUILD)2
-rw-r--r--tools/bazeldefs/defs.bzl (renamed from tools/build/defs.bzl)6
-rw-r--r--tools/bazeldefs/platforms.bzl17
-rw-r--r--tools/bazeldefs/tags.bzl40
-rw-r--r--tools/checkunsafe/BUILD2
-rw-r--r--tools/defs.bzl134
-rw-r--r--tools/go_generics/BUILD2
-rw-r--r--tools/go_generics/go_merge/BUILD2
-rw-r--r--tools/go_marshal/BUILD5
-rw-r--r--tools/go_marshal/gomarshal/BUILD1
-rw-r--r--tools/go_marshal/gomarshal/generator.go128
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go443
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go67
-rw-r--r--tools/go_marshal/gomarshal/util.go30
-rw-r--r--tools/go_marshal/marshal/BUILD3
-rw-r--r--tools/go_marshal/marshal/marshal.go50
-rw-r--r--tools/go_marshal/test/BUILD14
-rw-r--r--tools/go_marshal/test/benchmark_test.go2
-rw-r--r--tools/go_marshal/test/escape.go114
-rw-r--r--tools/go_marshal/test/test.go10
-rw-r--r--tools/go_stateify/BUILD3
-rw-r--r--tools/go_stateify/defs.bzl14
-rw-r--r--tools/go_stateify/main.go130
-rw-r--r--tools/images/BUILD4
-rw-r--r--tools/images/defs.bzl5
-rw-r--r--tools/installers/BUILD7
-rwxr-xr-xtools/installers/head.sh2
-rwxr-xr-xtools/installers/master.sh17
-rwxr-xr-xtools/tag_release.sh12
-rw-r--r--tools/tags/BUILD11
-rw-r--r--tools/tags/tags.go89
639 files changed, 31663 insertions, 6738 deletions
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..a2a260538
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,19 @@
+language: minimal
+sudo: required
+dist: xenial
+cache:
+ directories:
+ - /home/travis/.cache/bazel/
+services:
+ - docker
+matrix:
+ include:
+ - os: linux
+ arch: amd64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_amd64_pure_stripped/runsc
+ - os: linux
+ arch: arm64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_arm64_pure_stripped/runsc
+script:
+ - uname -a
+ - make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 55a1ad0d9..71650a4b8 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -47,7 +47,7 @@ Definitions for the rules below:
`core`:
* `//pkg/sentry/...`
-* Transitive dependencies in `//pkg/...`, `//third_party/...`.
+* Transitive dependencies in `//pkg/...`, etc.
`runsc`:
diff --git a/Dockerfile b/Dockerfile
index 738623023..2bfdfec6c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,8 +1,9 @@
-FROM ubuntu:bionic
+FROM fedora:31
-RUN apt-get update && apt-get install -y curl gnupg2 git python python3 python3-distutils python3-pip
-RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
- curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
-RUN apt-get update && apt-get install -y bazel && apt-get clean
+RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
+
+RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static
+
+RUN pip install pycparser
WORKDIR /gvisor
diff --git a/Makefile b/Makefile
index a73bc0c36..d9531fbd5 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,9 @@ UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/)
+# The --privileged is required to run tests.
+DOCKER_RUN_OPTIONS ?= --privileged
+
all: runsc
docker-build:
@@ -19,7 +22,7 @@ bazel-server-start: docker-build
-v "$(CURDIR):$(CURDIR)" \
--workdir "$(CURDIR)" \
--tmpfs /tmp:rw,exec \
- --privileged \
+ $(DOCKER_RUN_OPTIONS) \
gvisor-bazel \
sh -c "while :; do sleep 100; done" && \
docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor"
diff --git a/WORKSPACE b/WORKSPACE
index 5d2fc36f9..a15238a2e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -4,19 +4,19 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
# Load go bazel rules and gazelle.
http_archive(
name = "io_bazel_rules_go",
- sha256 = "b27e55d2dcc9e6020e17614ae6e0374818a3e3ce6f2024036e688ada24110444",
+ sha256 = "f99a9d76e972e0c8f935b2fe6d0d9d778f67c760c6d2400e23fc2e469016e2bd",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.21.0/rules_go-v0.21.0.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.21.0/rules_go-v0.21.0.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.21.2/rules_go-v0.21.2.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "86c6d481b3f7aedc1d60c1c211c6f76da282ae197c3b3160f54bd3a8f847896f",
+ sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.1/bazel-gazelle-v0.19.1.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
],
)
@@ -25,7 +25,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13.6",
+ go_version = "1.13.7",
nogo = "@//:nogo",
)
@@ -33,6 +33,20 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
+# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories"
+# block below once 1876 is fixed.
+#
+# The com_google_protobuf repository below would trigger downloading a older
+# version of org_golang_x_sys. If putting this repository statment in a place
+# after that of the com_google_protobuf, this statement will not work as
+# expectd to download a new version of org_golang_x_sys.
+go_repository(
+ name = "org_golang_x_sys",
+ importpath = "golang.org/x/sys",
+ sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=",
+ version = "v0.0.0-20191220220014-0732a990476f",
+)
+
# Load C++ rules.
http_archive(
name = "rules_cc",
@@ -257,13 +271,6 @@ go_repository(
)
go_repository(
- name = "org_golang_x_sys",
- importpath = "golang.org/x/sys",
- sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=",
- version = "v0.0.0-20190215142949-d0b11bdaac8a",
-)
-
-go_repository(
name = "org_golang_x_time",
commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
@@ -330,3 +337,13 @@ http_archive(
"https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
],
)
+
+http_archive(
+ name = "com_google_benchmark",
+ sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a",
+ strip_prefix = "benchmark-1.5.0",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ "https://github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ ],
+)
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
index 1455c6c5b..2a2d15d7e 100644
--- a/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -1,10 +1,28 @@
package(licenses = ["notice"])
+config_setting(
+ name = "gcloud_rule",
+ values = {
+ "define": "gcloud=off",
+ },
+)
+
py_binary(
name = "benchmarks",
srcs = ["run.py"],
+ data = select({
+ ":gcloud_rule": [],
+ "//conditions:default": [
+ "//tools/images:ubuntu1604",
+ "//tools/images:zone",
+ ],
+ }),
main = "run.py",
python_version = "PY3",
srcs_version = "PY3",
+ tags = [
+ "local",
+ "manual",
+ ],
deps = ["//benchmarks/runner"],
)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index ff21614c5..6d1ea3ae2 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -26,10 +26,14 @@ For configuring the environment manually, consult the
## Running benchmarks
-Run the following from the benchmarks directory:
+### Locally
+
+The tool is built to, by default, use Google Cloud Platform to run benchmarks,
+but it does support GCP workflows. To run locally, run the following from the
+benchmarks directory:
```bash
-bazel run :benchmarks -- run-local startup
+bazel run --define gcloud=off :benchmarks -- run-local startup
...
method,metric,result
@@ -44,17 +48,16 @@ runtime, runc. Running on another installed runtime, like say runsc, is as
simple as:
```bash
-bazel run :benchmakrs -- run-local startup --runtime=runsc
+bazel run --define gcloud=off :benchmarks -- run-local startup --runtime=runsc
```
-There is help: ``bash bash bazel run :benchmarks -- --help bazel
-run :benchmarks -- run-local --help` ``
+There is help: `bash bazel run --define gcloud=off :benchmarks -- --help bazel
+run --define gcloud=off :benchmarks -- run-local --help`
To list available benchmarks, use the `list` commmand:
```bash
-bazel run :benchmarks -- list
-ls
+bazel --define gcloud=off run :benchmarks -- list
...
Benchmark: sysbench.cpu
@@ -67,7 +70,7 @@ Metrics: events_per_second
You can choose benchmarks by name or regex like:
```bash
-bazel run :benchmarks -- run-local startup.node
+bazel run --define gcloud=off :benchmarks -- run-local startup.node
...
metric,result
startup_time_ms,1671.7178000000001
@@ -77,7 +80,7 @@ startup_time_ms,1671.7178000000001
or
```bash
-bazel run :benchmarks -- run-local s
+bazel run --define gcloud=off :benchmarks -- run-local s
...
method,metric,result
startup.empty,startup_time_ms,1792.8292
@@ -95,15 +98,32 @@ You can run parameterized benchmarks, for example to run with different
runtimes:
```bash
-bazel run :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu
+bazel run --define gcloud=off :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu
```
Or with different parameters:
```bash
-bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
+bazel run --define gcloud=off :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
+```
+
+### On Google Compute Engine (GCE)
+
+Benchmarks may be run on GCE in an automated way. The default project configured
+for `gcloud` will be used.
+
+An additional parameter `installers` may be provided to ensure that the latest
+runtime is installed from the workspace. See the files in `tools/installers` for
+supported install targets.
+
+```bash
+bazel run :benchmarks -- run-gcp --installers=head --runtime=runsc sysbench.cpu
```
+When running on GCE, the scripts generate a per run SSH key, which is added to
+your project. The key is set to expire in GCE after 60 minutes and is stored in
+a temporary directory on the local machine running the scripts.
+
## Writing benchmarks
To write new benchmarks, you should familiarize yourself with the structure of
diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl
new file mode 100644
index 000000000..56d28223e
--- /dev/null
+++ b/benchmarks/defs.bzl
@@ -0,0 +1,14 @@
+"""Provides attributes common to many workload tests."""
+
+load("//tools:defs.bzl", "py_requirement")
+
+test_deps = [
+ py_requirement("attrs", direct = False),
+ py_requirement("atomicwrites", direct = False),
+ py_requirement("more-itertools", direct = False),
+ py_requirement("pathlib2", direct = False),
+ py_requirement("pluggy", direct = False),
+ py_requirement("py", direct = False),
+ py_requirement("pytest"),
+ py_requirement("six", direct = False),
+]
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 52d4e42f8..48c548d59 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -1,13 +1,33 @@
-load("//tools:defs.bzl", "py_library", "py_requirement")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement")
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
+pkg_tar(
+ name = "installers",
+ srcs = [
+ "//tools/installers:head",
+ "//tools/installers:master",
+ "//tools/installers:runsc",
+ ],
+ mode = "0755",
+)
+
+filegroup(
+ name = "files",
+ srcs = [
+ ":installers",
+ ],
+)
+
py_library(
name = "harness",
srcs = ["__init__.py"],
+ data = [
+ ":files",
+ ],
)
py_library(
@@ -25,16 +45,43 @@ py_library(
srcs = ["container.py"],
deps = [
"//benchmarks/workloads",
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
@@ -47,17 +94,47 @@ py_library(
"//benchmarks/harness:ssh_connection",
"//benchmarks/harness:tunnel_dispatcher",
"//benchmarks/harness/machine_mocks",
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("six", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "six",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
@@ -66,10 +143,16 @@ py_library(
srcs = ["ssh_connection.py"],
deps = [
"//benchmarks/harness",
- py_requirement("bcrypt", False),
- py_requirement("cffi", True),
- py_requirement("paramiko", True),
- py_requirement("cryptography", False),
+ py_requirement(
+ "bcrypt",
+ direct = False,
+ ),
+ py_requirement("cffi"),
+ py_requirement("paramiko"),
+ py_requirement(
+ "cryptography",
+ direct = False,
+ ),
],
)
@@ -77,16 +160,43 @@ py_library(
name = "tunnel_dispatcher",
srcs = ["tunnel_dispatcher.py"],
deps = [
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("pexpect", True),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement("pexpect"),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py
index 61fd25f73..15aa2a69a 100644
--- a/benchmarks/harness/__init__.py
+++ b/benchmarks/harness/__init__.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,18 +15,48 @@
import getpass
import os
+import subprocess
+import tempfile
# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a
# format string that accepts a single string parameter.
-LOCAL_WORKLOADS_PATH = os.path.join(
- os.path.dirname(__file__), "../workloads/{}/tar.tar")
+LOCAL_WORKLOADS_PATH = os.path.dirname(__file__) + "/../workloads/{}/tar.tar"
# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the
# remote host. This is a format string that accepts a single string parameter.
REMOTE_WORKLOADS_PATH = "workloads/{}"
+# INSTALLER_ROOT is the set of files that needs to be copied.
+INSTALLER_ARCHIVE = os.readlink(os.path.join(
+ os.path.dirname(__file__), "installers.tar"))
+
+# SSH_KEY_DIR holds SSH_PRIVATE_KEY for this run. bm-tools paramiko requires
+# keys generated with the '-t rsa -m PEM' options from ssh-keygen. This is
+# abstracted away from the user.
+SSH_KEY_DIR = tempfile.TemporaryDirectory()
+SSH_PRIVATE_KEY = "key"
+
# DEFAULT_USER is the default user running this script.
DEFAULT_USER = getpass.getuser()
# DEFAULT_USER_HOME is the home directory of the user running the script.
DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else ""
+
+# Default directory to remotely installer "installer" targets.
+REMOTE_INSTALLERS_PATH = "installers"
+
+
+def make_key():
+ """Wraps a valid ssh key in a temporary directory."""
+ path = os.path.join(SSH_KEY_DIR.name, SSH_PRIVATE_KEY)
+ if not os.path.exists(path):
+ cmd = "ssh-keygen -t rsa -m PEM -b 4096 -f {key} -q -N".format(
+ key=path).split(" ")
+ cmd.append("")
+ subprocess.run(cmd, check=True)
+ return path
+
+
+def delete_key():
+ """Deletes temporary directory containing private key."""
+ SSH_KEY_DIR.cleanup()
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 2df4c9e31..5bdc4aa85 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,10 +29,11 @@ to run contianers.
"""
import logging
+import os
import re
import subprocess
import time
-from typing import Tuple
+from typing import List, Tuple
import docker
@@ -42,6 +43,8 @@ from benchmarks.harness import machine_mocks
from benchmarks.harness import ssh_connection
from benchmarks.harness import tunnel_dispatcher
+log = logging.getLogger(__name__)
+
class Machine(object):
"""The machine object is the primary object for benchmarks.
@@ -201,6 +204,7 @@ class RemoteMachine(Machine):
self._tunnel = tunnel_dispatcher.Tunnel(name, **kwargs)
self._tunnel.connect()
self._docker_client = self._tunnel.get_docker_client()
+ self._has_installers = False
def run(self, cmd: str) -> Tuple[str, str]:
return self._ssh_connection.run(cmd)
@@ -210,14 +214,46 @@ class RemoteMachine(Machine):
stdout, stderr = self._ssh_connection.run("cat '{}'".format(path))
return stdout + stderr
+ def install(self,
+ installer: str,
+ results: List[bool] = None,
+ index: int = -1):
+ """Method unique to RemoteMachine to handle installation of installers.
+
+ Handles installers, which install things that may change between runs (e.g.
+ runsc). Usually called from gcloud_producer, which expects this method to
+ to store results.
+
+ Args:
+ installer: the installer target to run.
+ results: Passed by the caller of where to store success.
+ index: Index for this method to store the result in the passed results
+ list.
+ """
+ # This generates a tarball of the full installer root (which will generate
+ # be the full bazel root directory) and sends it over.
+ if not self._has_installers:
+ archive = self._ssh_connection.send_installers()
+ self.run("tar -xvf {archive} -C {dir}".format(
+ archive=archive, dir=harness.REMOTE_INSTALLERS_PATH))
+ self._has_installers = True
+
+ # Execute the remote installer.
+ self.run("sudo {dir}/{file}".format(
+ dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+
+ if results:
+ results[index] = True
+
def pull(self, workload: str) -> str:
# Push to the remote machine and build.
logging.info("Building %s@%s remotely...", workload, self._name)
remote_path = self._ssh_connection.send_workload(workload)
+ remote_dir = os.path.dirname(remote_path)
# Workloads are all tarballs.
- self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format(
- remote_path=remote_path))
- self.run("docker build --tag={} {}".format(workload, remote_path))
+ self.run("tar -xvf {remote_path} -C {remote_dir}".format(
+ remote_path=remote_path, remote_dir=remote_dir))
+ self.run("docker build --tag={} {}".format(workload, remote_dir))
return workload # Workload is the tag.
def container(self, image: str, **kwargs) -> container.Container:
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index 48ea0ef39..81f19bd08 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -31,7 +31,10 @@ py_library(
deps = [
"//benchmarks/harness:machine",
"//benchmarks/harness/machine_producers:machine_producer",
- py_requirement("PyYAML", False),
+ py_requirement(
+ "PyYAML",
+ direct = False,
+ ),
],
)
@@ -76,5 +79,6 @@ py_test(
python_version = "PY3",
tags = [
"local",
+ "manual",
],
)
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
index e0b77d52b..513d16e4f 100644
--- a/benchmarks/harness/machine_producers/gcloud_producer.py
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,12 +46,11 @@ class GCloudProducer(machine_producer.MachineProducer):
Produces Machine objects backed by GCP instances.
Attributes:
- project: The GCP project name under which to create the machines.
- ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
image: image name as a string.
- image_project: image project as a string.
- machine_type: type of GCP to create. e.g. n1-standard-4
zone: string to a valid GCP zone.
+ machine_type: type of GCP to create (e.g. n1-standard-4).
+ installers: list of installers post-boot.
+ ssh_key_file: path to a valid ssh private key. See README on vaild ssh keys.
ssh_user: string of user name for ssh_key
ssh_password: string of password for ssh key
mock: a mock printer which will print mock data if required. Mock data is
@@ -60,21 +59,19 @@ class GCloudProducer(machine_producer.MachineProducer):
"""
def __init__(self,
- project: str,
- ssh_key_file: str,
image: str,
- image_project: str,
- machine_type: str,
zone: str,
+ machine_type: str,
+ installers: List[str],
+ ssh_key_file: str,
ssh_user: str,
ssh_password: str,
mock: gcloud_mock_recorder.MockPrinter = None):
- self.project = project
- self.ssh_key_file = ssh_key_file
self.image = image
- self.image_project = image_project
- self.machine_type = machine_type
self.zone = zone
+ self.machine_type = machine_type
+ self.installers = installers
+ self.ssh_key_file = ssh_key_file
self.ssh_user = ssh_user
self.ssh_password = ssh_password
self.mock = mock
@@ -87,10 +84,34 @@ class GCloudProducer(machine_producer.MachineProducer):
"Cannot ask for {num} machines!".format(num=num_machines))
with self.condition:
names = self._get_unique_names(num_machines)
- self._build_instances(names)
- instances = self._start_command(names)
+ instances = self._build_instances(names)
self._add_ssh_key_to_instances(names)
- return self._machines_from_instances(instances)
+ machines = self._machines_from_instances(instances)
+
+ # Install all bits in lock-step.
+ #
+ # This will perform paralell installations for however many machines we
+ # have, but it's easy to track errors because if installing (a, b, c), we
+ # won't install "c" until "b" is installed on all machines.
+ for installer in self.installers:
+ threads = [None] * len(machines)
+ results = [False] * len(machines)
+ for i in range(len(machines)):
+ threads[i] = threading.Thread(
+ target=machines[i].install, args=(installer, results, i))
+ threads[i].start()
+ for thread in threads:
+ thread.join()
+ for result in results:
+ if not result:
+ raise NotImplementedError(
+ "Installers failed on at least one machine!")
+
+ # Add this user to each machine's docker group.
+ for m in machines:
+ m.run("sudo setfacl -m user:$USER:rw /var/run/docker.sock")
+
+ return machines
def release_machines(self, machine_list: List[machine.Machine]):
"""Releases the requested number of machines, deleting the instances."""
@@ -123,15 +144,7 @@ class GCloudProducer(machine_producer.MachineProducer):
def _get_unique_names(self, num_names) -> List[str]:
"""Returns num_names unique names based on data from the GCP project."""
- curr_machines = self._list_machines()
- curr_names = set([machine["name"] for machine in curr_machines])
- ret = []
- while len(ret) < num_names:
- new_name = "machine-" + str(uuid.uuid4())
- if new_name not in curr_names:
- ret.append(new_name)
- curr_names.update(new_name)
- return ret
+ return ["machine-" + str(uuid.uuid4()) for _ in range(0, num_names)]
def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
"""Creates instances using gcloud command.
@@ -151,34 +164,9 @@ class GCloudProducer(machine_producer.MachineProducer):
"_build_instances cannot create instances without names.")
cmd = "gcloud compute instances create".split(" ")
cmd.extend(names)
- cmd.extend(
- "--preemptible --image={image} --zone={zone} --machine-type={machine_type}"
- .format(
- image=self.image, zone=self.zone,
- machine_type=self.machine_type).split(" "))
- if self.image_project:
- cmd.append("--image-project={project}".format(project=self.image_project))
- res = self._run_command(cmd)
- return json.loads(res.stdout)
-
- def _start_command(self, names):
- """Starts instances using gcloud command.
-
- Runs the command `gcloud compute instances start` on list of instances by
- name and returns json data on started instances on success.
-
- Args:
- names: list of names of instances to start.
-
- Returns:
- List of json data describing started machines.
- """
- if not names:
- raise ValueError("_start_command cannot start empty instance list.")
- cmd = "gcloud compute instances start".split(" ")
- cmd.extend(names)
- cmd.append("--zone={zone}".format(zone=self.zone))
- cmd.append("--project={project}".format(project=self.project))
+ cmd.append("--image=" + self.image)
+ cmd.append("--zone=" + self.zone)
+ cmd.append("--machine-type=" + self.machine_type)
res = self._run_command(cmd)
return json.loads(res.stdout)
@@ -186,7 +174,7 @@ class GCloudProducer(machine_producer.MachineProducer):
"""Adds ssh key to instances by calling gcloud ssh command.
Runs the command `gcloud compute ssh instance_name` on list of images by
- name. Tries to ssh into given instance
+ name. Tries to ssh into given instance.
Args:
names: list of machine names to which to add the ssh-key
@@ -202,30 +190,18 @@ class GCloudProducer(machine_producer.MachineProducer):
cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
cmd.append("--zone={zone}".format(zone=self.zone))
cmd.append("--command=uname")
+ cmd.append("--ssh-key-expire-after=60m")
timeout = datetime.timedelta(seconds=5 * 60)
start = datetime.datetime.now()
while datetime.datetime.now() <= timeout + start:
try:
self._run_command(cmd)
break
- except subprocess.CalledProcessError as e:
+ except subprocess.CalledProcessError:
if datetime.datetime.now() > timeout + start:
raise TimeoutError(
"Could not SSH into instance after 5 min: {name}".format(
name=name))
- # 255 is the returncode for ssh connection refused.
- elif e.returncode == 255:
-
- continue
- else:
- raise e
-
- def _list_machines(self) -> List[Dict[str, Any]]:
- """Runs `list` gcloud command and returns list of Machine data."""
- cmd = "gcloud compute instances list --project {project}".format(
- project=self.project).split(" ")
- res = self._run_command(cmd)
- return json.loads(res.stdout)
def _run_command(self,
cmd: List[str],
@@ -261,7 +237,7 @@ class GCloudProducer(machine_producer.MachineProducer):
self.mock.record(res)
if res.returncode != 0:
raise subprocess.CalledProcessError(
- cmd=res.args,
+ cmd=" ".join(res.args),
output=res.stdout,
stderr=res.stderr,
returncode=res.returncode)
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index e0bf258f1..b8c8e42d4 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
"""SSHConnection handles the details of SSH connections."""
+import logging
import os
import warnings
@@ -23,19 +24,27 @@ from benchmarks import harness
# Get rid of paramiko Cryptography Warnings.
warnings.filterwarnings(action="ignore", module=".*paramiko.*")
+log = logging.getLogger(__name__)
-def send_one_file(client: paramiko.SSHClient, path: str, remote_dir: str):
+
+def send_one_file(client: paramiko.SSHClient, path: str,
+ remote_dir: str) -> str:
"""Sends a single file via an SSH client.
Args:
client: The existing SSH client.
path: The local path.
remote_dir: The remote directory.
+
+ Returns:
+ :return: The remote path as a string.
"""
filename = path.split("/").pop()
- client.exec_command("mkdir -p " + remote_dir)
+ if remote_dir != ".":
+ client.exec_command("mkdir -p " + remote_dir)
with client.open_sftp() as ftp_client:
ftp_client.put(path, os.path.join(remote_dir, filename))
+ return os.path.join(remote_dir, filename)
class SSHConnection:
@@ -87,10 +96,13 @@ class SSHConnection:
The contents of stdout and stderr.
"""
with self._client() as client:
+ log.info("running command: %s", cmd)
_, stdout, stderr = client.exec_command(command=cmd)
- stdout.channel.recv_exit_status()
+ log.info("returned status: %d", stdout.channel.recv_exit_status())
stdout = stdout.read().decode("utf-8")
stderr = stderr.read().decode("utf-8")
+ log.info("stdout: %s", stdout)
+ log.info("stderr: %s", stderr)
return stdout, stderr
def send_workload(self, name: str) -> str:
@@ -103,6 +115,12 @@ class SSHConnection:
The remote path.
"""
with self._client() as client:
- send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
- harness.REMOTE_WORKLOADS_PATH.format(name))
- return harness.REMOTE_WORKLOADS_PATH.format(name)
+ return send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
+ harness.REMOTE_WORKLOADS_PATH.format(name))
+
+ def send_installers(self) -> str:
+ with self._client() as client:
+ return send_one_file(
+ client,
+ path=harness.INSTALLER_ARCHIVE,
+ remote_dir=harness.REMOTE_INSTALLERS_PATH)
diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD
index fae0ca800..471debfdf 100644
--- a/benchmarks/runner/BUILD
+++ b/benchmarks/runner/BUILD
@@ -1,4 +1,5 @@
load("//tools:defs.bzl", "py_library", "py_requirement", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(licenses = ["notice"])
@@ -28,7 +29,7 @@ py_library(
"//benchmarks/suites:startup",
"//benchmarks/suites:sysbench",
"//benchmarks/suites:syscall",
- py_requirement("click", True),
+ py_requirement("click"),
],
)
@@ -36,7 +37,7 @@ py_library(
name = "commands",
srcs = ["commands.py"],
deps = [
- py_requirement("click", True),
+ py_requirement("click"),
],
)
@@ -48,16 +49,8 @@ py_test(
"local",
"manual",
],
- deps = [
+ deps = test_deps + [
":runner",
- py_requirement("click", True),
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
+ py_requirement("click"),
],
)
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index ba80d83d7..ba27dc69f 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,10 @@
import copy
import csv
-import json
import logging
-import os
import pkgutil
import pydoc
import re
-import subprocess
import sys
import types
from typing import List
@@ -123,57 +120,29 @@ def run_mock(ctx, **kwargs):
@runner.command("run-gcp", commands.GCPCommand)
@click.pass_context
-def run_gcp(ctx, project: str, ssh_key_file: str, image: str,
- image_project: str, machine_type: str, zone: str, ssh_user: str,
- ssh_password: str, **kwargs):
+def run_gcp(ctx, image_file: str, zone_file: str, machine_type: str,
+ installers: List[str], **kwargs):
"""Runs all benchmarks on GCP instances."""
- if not ssh_user:
- ssh_user = harness.DEFAULT_USER
-
- # Get the default project if one was not provided.
- if not project:
- sub = subprocess.run(
- "gcloud config get-value project".split(" "), stdout=subprocess.PIPE)
- if sub.returncode:
- raise ValueError(
- "Cannot get default project from gcloud. Is it configured>")
- project = sub.stdout.decode("utf-8").strip("\n")
-
- if not image_project:
- image_project = project
-
- # Check that the ssh-key exists and is readable.
- if not os.access(ssh_key_file, os.R_OK):
- raise ValueError(
- "ssh key given `{ssh_key}` is does not exist or is not readable."
- .format(ssh_key=ssh_key_file))
-
- # Check that the image exists.
- sub = subprocess.run(
- "gcloud compute images describe {image} --project {image_project} --format=json"
- .format(image=image, image_project=image_project).split(" "),
- stdout=subprocess.PIPE)
- if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]:
- raise ValueError(
- "given image was not found or is not ready: {image} {image_project}."
- .format(image=image, image_project=image_project))
-
- # Check and set zone to default.
- if not zone:
- sub = subprocess.run(
- "gcloud config get-value compute/zone".split(" "),
- stdout=subprocess.PIPE)
- if sub.returncode:
- raise ValueError(
- "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag."
- )
- zone = sub.stdout.decode("utf-8").strip("\n")
-
- producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image,
- image_project, machine_type, zone,
- ssh_user, ssh_password)
- run(ctx, producer, **kwargs)
+ # Resolve all files.
+ image = open(image_file).read().rstrip()
+ zone = open(zone_file).read().rstrip()
+
+ key_file = harness.make_key()
+
+ producer = gcloud_producer.GCloudProducer(
+ image,
+ zone,
+ machine_type,
+ installers,
+ ssh_key_file=key_file,
+ ssh_user=harness.DEFAULT_USER,
+ ssh_password="")
+
+ try:
+ run(ctx, producer, **kwargs)
+ finally:
+ harness.delete_key()
def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int,
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
index 7ab12fac6..0fccb2fad 100644
--- a/benchmarks/runner/commands.py
+++ b/benchmarks/runner/commands.py
@@ -1,5 +1,5 @@
# python3
-# Copyright 2019 Google LLC
+# Copyright 2019 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,9 @@ def run_mock(**kwargs):
# mock implementation
"""
-import click
+import os
-from benchmarks import harness
+import click
class RunCommand(click.core.Command):
@@ -90,46 +90,40 @@ class GCPCommand(RunCommand):
"""GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method.
Attributes:
- project: GCP project
- ssh_key_path: path to the ssh-key to use for the run
- image: name of the image to build machines from
- image_project: GCP project under which to find image
- zone: a GCP zone (e.g. us-west1-b)
- ssh_user: username to use for the ssh-key
- ssh_password: password to use for the ssh-key
+ image_file: name of the image to build machines from
+ zone_file: a GCP zone (e.g. us-west1-b)
+ installers: named installers for post-create
+ machine_type: type of machine to create (e.g. n1-standard-4)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- project = click.core.Option(
- ("--project",),
- help="Project to run on if not default value given by 'gcloud config get-value project'."
+ image_file = click.core.Option(
+ ("--image_file",),
+ help="The file containing the image for VMs.",
+ default=os.path.join(
+ os.path.dirname(__file__), "../../tools/images/ubuntu1604.txt"),
+ )
+ zone_file = click.core.Option(
+ ("--zone_file",),
+ help="The file containing the GCP zone.",
+ default=os.path.join(
+ os.path.dirname(__file__), "../../tools/images/zone.txt"),
+ )
+ installers = click.core.Option(
+ ("--installers",),
+ help="The set of installers to use.",
+ multiple=True,
+ )
+ machine_type = click.core.Option(
+ ("--machine_type",),
+ help="Type to make all machines.",
+ default="n1-standard-4",
)
- ssh_key_path = click.core.Option(
- ("--ssh-key-file",),
- help="Path to a valid ssh private key to use. See README on generating a valid ssh key. Set to ~/.ssh/benchmark-tools by default.",
- default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools")
- image = click.core.Option(("--image",),
- help="The image on which to build VMs.",
- default="bm-tools-testing")
- image_project = click.core.Option(
- ("--image_project",),
- help="The project under which the image to be used is listed.",
- default="")
- machine_type = click.core.Option(("--machine_type",),
- help="Type to make all machines.",
- default="n1-standard-4")
- zone = click.core.Option(("--zone",),
- help="The GCP zone to run on.",
- default="")
- ssh_user = click.core.Option(("--ssh-user",),
- help="User for the ssh key.",
- default=harness.DEFAULT_USER)
- ssh_password = click.core.Option(("--ssh-password",),
- help="Password for the ssh key.",
- default="")
self.params.extend([
- project, ssh_key_path, image, image_project, machine_type, zone,
- ssh_user, ssh_password
+ image_file,
+ zone_file,
+ machine_type,
+ installers,
])
diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go
index 72ada5700..73b7c4f5b 100644
--- a/benchmarks/tcp/tcp_proxy.go
+++ b/benchmarks/tcp/tcp_proxy.go
@@ -274,7 +274,7 @@ func (n netstackImpl) listen(port int) (net.Listener, error) {
NIC: nicID,
Port: uint16(port),
}
- listener, err := gonet.NewListener(n.s, addr, ipv4.ProtocolNumber)
+ listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, err
}
diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD
index 4dd91ceb3..945ac7026 100644
--- a/benchmarks/workloads/ab/BUILD
+++ b/benchmarks/workloads/ab/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "ab_test",
srcs = ["ab_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":ab",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD
index 55dae3baa..bb1a308bf 100644
--- a/benchmarks/workloads/absl/BUILD
+++ b/benchmarks/workloads/absl/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "absl_test",
srcs = ["absl_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":absl",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD
index 7b78e8e75..24d909c53 100644
--- a/benchmarks/workloads/fio/BUILD
+++ b/benchmarks/workloads/fio/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "fio_test",
srcs = ["fio_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":fio",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD
index 570f40148..91b953718 100644
--- a/benchmarks/workloads/iperf/BUILD
+++ b/benchmarks/workloads/iperf/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "iperf_test",
srcs = ["iperf_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":iperf",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD
index f472a4443..147cfedd2 100644
--- a/benchmarks/workloads/redisbenchmark/BUILD
+++ b/benchmarks/workloads/redisbenchmark/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "redisbenchmark_test",
srcs = ["redisbenchmark_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":redisbenchmark",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/ruby/Gemfile.lock b/benchmarks/workloads/ruby/Gemfile.lock
index b44817bd3..da94aefc4 100644
--- a/benchmarks/workloads/ruby/Gemfile.lock
+++ b/benchmarks/workloads/ruby/Gemfile.lock
@@ -1,28 +1,41 @@
GEM
remote: https://rubygems.org/
specs:
+ activemerchant (1.105.0)
+ activesupport (>= 4.2)
+ builder (>= 2.1.2, < 4.0.0)
+ i18n (>= 0.6.9)
+ nokogiri (~> 1.4)
activesupport (5.2.3)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
+ bcrypt (3.1.13)
+ builder (3.2.4)
cassandra-driver (3.2.3)
ione (~> 1.2)
concurrent-ruby (1.1.5)
+ ffi (1.12.2)
i18n (1.6.0)
concurrent-ruby (~> 1.0)
ione (1.2.4)
+ mini_portile2 (2.4.0)
minitest (5.11.3)
mustermann (1.0.3)
+ nokogiri (1.10.8)
+ mini_portile2 (~> 2.4.0)
pdf-core (0.7.0)
prawn (2.2.2)
pdf-core (~> 0.7.0)
ttfunk (~> 1.5)
puma (3.12.1)
- rack (2.0.7)
+ rack (2.2.2)
rack-protection (2.0.5)
rack
rake (12.3.2)
+ rbnacl (7.1.1)
+ ffi
redis (4.1.1)
ruby-fann (1.2.6)
sinatra (2.0.5)
@@ -43,9 +56,12 @@ PLATFORMS
ruby
DEPENDENCIES
+ activemerchant
+ bcrypt
cassandra-driver
puma
rake
+ rbnacl
redis
ruby-fann
sinatra
diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD
index 59443b14a..72ed9403d 100644
--- a/benchmarks/workloads/ruby_template/BUILD
+++ b/benchmarks/workloads/ruby_template/BUILD
@@ -15,5 +15,4 @@ pkg_tar(
"index.erb",
"main.rb",
],
- strip_prefix = "third_party/gvisor/benchmarks/workloads/ruby_template",
)
diff --git a/benchmarks/workloads/ruby_template/Gemfile.lock b/benchmarks/workloads/ruby_template/Gemfile.lock
index dd8d56fb7..af03ed7fd 100644
--- a/benchmarks/workloads/ruby_template/Gemfile.lock
+++ b/benchmarks/workloads/ruby_template/Gemfile.lock
@@ -2,25 +2,25 @@ GEM
remote: https://rubygems.org/
specs:
mustermann (1.0.3)
- puma (3.12.0)
+ puma (3.12.2)
rack (2.0.6)
rack-protection (2.0.5)
rack
+ redis (4.1.0)
sinatra (2.0.5)
mustermann (~> 1.0)
rack (~> 2.0)
rack-protection (= 2.0.5)
tilt (~> 2.0)
tilt (2.0.9)
- redis (4.1.0)
PLATFORMS
ruby
DEPENDENCIES
puma
- sinatra
redis
+ sinatra
BUNDLED WITH
1.17.1 \ No newline at end of file
diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD
index 3834af7ed..ab2556064 100644
--- a/benchmarks/workloads/sysbench/BUILD
+++ b/benchmarks/workloads/sysbench/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "sysbench_test",
srcs = ["sysbench_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":sysbench",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD
index dba4bb1e7..f8c43bca1 100644
--- a/benchmarks/workloads/syscall/BUILD
+++ b/benchmarks/workloads/syscall/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "syscall_test",
srcs = ["syscall_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":syscall",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg
new file mode 100644
index 000000000..c48518a05
--- /dev/null
+++ b/kokoro/benchmark_tests.cfg
@@ -0,0 +1,26 @@
+build_file : 'repo/scripts/benchmark.sh'
+
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id : 73898
+ keyname : 'kokoro-rbe-service-account'
+ },
+ }
+}
+
+env_vars {
+ key : 'PROJECT'
+ value : 'gvisor-kokoro-testing'
+}
+
+env_vars {
+ key : 'ZONE'
+ value : 'us-central1-b'
+}
+
+env_vars {
+ key : 'KOKORO_SERVICE_ACCOUNT'
+ value : '73898_kokoro-rbe-service-account'
+}
diff --git a/kokoro/kythe/generate_xrefs.sh b/kokoro/kythe/generate_xrefs.sh
index 7a0fbb3cd..323b0f77b 100644
--- a/kokoro/kythe/generate_xrefs.sh
+++ b/kokoro/kythe/generate_xrefs.sh
@@ -23,7 +23,7 @@ bazel version
python3 -V
-readonly KYTHE_VERSION='v0.0.39'
+readonly KYTHE_VERSION='v0.0.41'
readonly WORKDIR="$(mktemp -d)"
readonly KYTHE_DIR="${WORKDIR}/kythe-${KYTHE_VERSION}"
if [[ -n "$KOKORO_GIT_COMMIT" ]]; then
diff --git a/kokoro/packetdrill_tests.cfg b/kokoro/packetdrill_tests.cfg
new file mode 100644
index 000000000..258d7deb4
--- /dev/null
+++ b/kokoro/packetdrill_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/packetdrill_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg
new file mode 100644
index 000000000..fd4911e88
--- /dev/null
+++ b/kokoro/runtime_tests/go1.12.cfg
@@ -0,0 +1,16 @@
+build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+
+env_vars {
+ key: "RUNTIME_TEST_NAME"
+ value: "go1.12"
+}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg
new file mode 100644
index 000000000..7f8611a08
--- /dev/null
+++ b/kokoro/runtime_tests/java11.cfg
@@ -0,0 +1,16 @@
+build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+
+env_vars {
+ key: "RUNTIME_TEST_NAME"
+ value: "java11"
+}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg
new file mode 100644
index 000000000..c67ad5567
--- /dev/null
+++ b/kokoro/runtime_tests/nodejs12.4.0.cfg
@@ -0,0 +1,16 @@
+build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+
+env_vars {
+ key: "RUNTIME_TEST_NAME"
+ value: "nodejs12.4.0"
+}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg
new file mode 100644
index 000000000..f266c5e26
--- /dev/null
+++ b/kokoro/runtime_tests/php7.3.6.cfg
@@ -0,0 +1,16 @@
+build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+
+env_vars {
+ key: "RUNTIME_TEST_NAME"
+ value: "php7.3.6"
+}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg
new file mode 100644
index 000000000..574add152
--- /dev/null
+++ b/kokoro/runtime_tests/python3.7.3.cfg
@@ -0,0 +1,16 @@
+build_file: "github/github/kokoro/runtime_tests/runtime_tests.sh"
+
+env_vars {
+ key: "RUNTIME_TEST_NAME"
+ value: "python3.7.3"
+}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/scripts/runtime_tests.sh b/kokoro/runtime_tests/runtime_tests.sh
index 9ee991e42..73a58f806 100755
--- a/scripts/runtime_tests.sh
+++ b/kokoro/runtime_tests/runtime_tests.sh
@@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-source $(dirname $0)/common.sh
+# Run in the root of the repo.
+cd "$(dirname "$0")"
+cd "$(git rev-parse --show-toplevel)"
+
+source scripts/common.sh
if [ ! -v RUNTIME_TEST_NAME ]; then
echo 'Must set $RUNTIME_TEST_NAME' >&2
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 1f3c0c687..322d1ccc4 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -17,6 +17,8 @@ go_library(
"dev.go",
"elf.go",
"epoll.go",
+ "epoll_amd64.go",
+ "epoll_arm64.go",
"errors.go",
"eventfd.go",
"exec.go",
@@ -28,6 +30,7 @@ go_library(
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
@@ -59,6 +62,7 @@ go_library(
"wait.go",
"xattr.go",
],
+ marshal = True,
visibility = ["//visibility:public"],
deps = [
"//pkg/abi",
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
index 421e11256..89f9a793f 100644
--- a/pkg/abi/linux/dev.go
+++ b/pkg/abi/linux/dev.go
@@ -36,6 +36,9 @@ func DecodeDeviceID(rdev uint32) (uint16, uint32) {
//
// See Documentations/devices.txt and uapi/linux/major.h.
const (
+ // MEM_MAJOR is the major device number for "memory" character devices.
+ MEM_MAJOR = 1
+
// TTYAUX_MAJOR is the major device number for alternate TTY devices.
TTYAUX_MAJOR = 5
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 0e881aa3c..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,12 +14,9 @@
package linux
-// EpollEvent is equivalent to struct epoll_event from epoll(2).
-type EpollEvent struct {
- Events uint32
- Fd int32
- Data int32
-}
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
// Event masks.
const (
@@ -60,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/abi/linux/epoll_amd64.go
index 876783e78..34ff18009 100644
--- a/pkg/usermem/usermem_unsafe.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package linux
-import (
- "unsafe"
-)
-
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event::data a __u64. We represent it as
+ // [2]int32 because, on amd64, Linux also makes struct epoll_event
+ // __attribute__((packed)), such that there is no padding between Events
+ // and Data.
+ Data [2]int32
}
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
new file mode 100644
index 000000000..f86c35329
--- /dev/null
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -0,0 +1,26 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
+ // here.
+ _ int32
+ Data [2]int32
+}
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c3ab15a4f..e229ac21c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -241,6 +241,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go
index 9d307e840..6b72364ea 100644
--- a/pkg/abi/linux/file_amd64.go
+++ b/pkg/abi/linux/file_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package linux
// Constants for open(2).
@@ -23,6 +25,8 @@ const (
)
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go
index 26a54f416..6492c9038 100644
--- a/pkg/abi/linux/file_arm64.go
+++ b/pkg/abi/linux/file_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package linux
// Constants for open(2).
@@ -23,6 +25,8 @@ const (
)
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 2c652baa2..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2062e6a4b 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -72,3 +72,29 @@ const (
SIOCGMIIPHY = 0x8947
SIOCGMIIREG = 0x8948
)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
new file mode 100644
index 000000000..c59c9c136
--- /dev/null
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
+)
+
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 33fcc6c95..bd2e13ba1 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -198,6 +198,13 @@ type XTEntryMatch struct {
// SizeOfXTEntryMatch is the size of an XTEntryMatch.
const SizeOfXTEntryMatch = 32
+// KernelXTEntryMatch is identical to XTEntryMatch, but contains
+// variable-length Data field.
+type KernelXTEntryMatch struct {
+ XTEntryMatch
+ Data []byte
+}
+
// XTEntryTarget holds a target for a rule. For example, it can specify that
// packets matching the rule should DROP, ACCEPT, or use an extension target.
// iptables-extension(8) has a list of possible targets.
@@ -218,11 +225,14 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
-// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
-// RETURN. It corresponds to struct xt_standard_target in
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
+// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTStandardTarget struct {
- Target XTEntryTarget
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
Verdict int32
_ [4]byte
}
@@ -340,3 +350,96 @@ func goString(cstring []byte) string {
}
return string(cstring)
}
+
+// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTTCP struct {
+ // SourcePortStart specifies the inclusive start of the range of source
+ // ports to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd specifies the inclusive end of the range of source ports
+ // to which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart specifies the start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd specifies the end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // Option specifies that a particular TCP option must be set.
+ Option uint8
+
+ // FlagMask masks TCP flags when comparing to the FlagCompare byte. It allows
+ // for specification of which flags are important to the matcher.
+ FlagMask uint8
+
+ // FlagCompare, in combination with FlagMask, is used to match only packets
+ // that have certain flags set.
+ FlagCompare uint8
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // TX_TCP_INV_* flags.
+ InverseFlags uint8
+}
+
+// SizeOfXTTCP is the size of an XTTCP.
+const SizeOfXTTCP = 12
+
+// Flags in XTTCP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_TCP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_TCP_INV_DSTPT = 0x02
+ // Invert the meaning of FlagCompare.
+ XT_TCP_INV_FLAGS = 0x04
+ // Invert the meaning of Option.
+ XT_TCP_INV_OPTION = 0x08
+ // Enable all flags.
+ XT_TCP_INV_MASK = 0x0F
+)
+
+// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTUDP struct {
+ // SourcePortStart is the inclusive start of the range of source ports
+ // to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd is the inclusive end of the range of source ports to
+ // which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart is the inclusive start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd is the inclusive end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // TX_UDP_INV_* flags.
+ InverseFlags uint8
+
+ _ uint8
+}
+
+// SizeOfXTUDP is the size of an XTUDP.
+const SizeOfXTUDP = 10
+
+// Flags in XTUDP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_UDP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_UDP_INV_DSTPT = 0x02
+ // Enable all flags.
+ XT_UDP_INV_MASK = 0x03
+)
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 766ee4014..4a14ef691 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -411,6 +411,15 @@ type ControlMessageCredentials struct {
GID uint32
}
+// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
@@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1
// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
const SizeOfControlMessageTClass = 4
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index 5c5a58cd4..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -101,6 +101,8 @@ func NsecToTimeT(nsec int64) TimeT {
}
// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
type Timespec struct {
Sec int64
Nsec int64
@@ -155,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -228,6 +232,8 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
@@ -256,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index a3b6406fa..99180b208 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -18,6 +18,7 @@ package linux
const (
XATTR_NAME_MAX = 255
XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
XATTR_CREATE = 1
XATTR_REPLACE = 2
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 3948074ba..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -5,10 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "atomicbitops",
srcs = [
- "atomic_bitops.go",
- "atomic_bitops_amd64.s",
- "atomic_bitops_arm64.s",
- "atomic_bitops_common.go",
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
],
visibility = ["//:sandbox"],
)
@@ -16,7 +16,7 @@ go_library(
go_test(
name = "atomicbitops_test",
size = "small",
- srcs = ["atomic_bitops_test.go"],
+ srcs = ["atomicbitops_test.go"],
library = ":atomicbitops",
deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go
index fcc41a9ea..1be081719 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -14,47 +14,34 @@
// +build amd64 arm64
-// Package atomicbitops provides basic bitwise operations in an atomic way.
-// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives, and the arm64 leverages the LDAXR
-// and STLXR pair primitives.
+// Package atomicbitops provides extensions to the sync/atomic package.
//
-// WARNING: the bitwise ops provided in this package doesn't imply any memory
-// ordering. Using them to construct locks must employ proper memory barriers.
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
package atomicbitops
-// AndUint32 atomically applies bitwise and operation to *addr with val.
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
func AndUint32(addr *uint32, val uint32)
-// OrUint32 atomically applies bitwise or operation to *addr with val.
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
func OrUint32(addr *uint32, val uint32)
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
func XorUint32(addr *uint32, val uint32)
// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
-// AndUint64 atomically applies bitwise and operation to *addr with val.
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
func AndUint64(addr *uint64, val uint64)
-// OrUint64 atomically applies bitwise or operation to *addr with val.
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
func OrUint64(addr *uint64, val uint64)
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
func XorUint64(addr *uint64, val uint64)
// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index db0972001..54c887ee5 100644
--- a/pkg/atomicbitops/atomic_bitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
CMPXCHGQ DX, 0(DI)
MOVQ AX, ret+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- TESTL AX, AX
- JZ fail
- LEAL 1(AX), DX
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB AX, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- LEAL -1(AX), DX
- TESTL DX, DX
- JZ fail
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB DX, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 97f8808c1..5c780851b 100644
--- a/pkg/atomicbitops/atomic_bitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
-
again:
LDAXRW (R0), R3
CMPW R1, R3
@@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
-
again:
LDAXR (R0), R3
CMP R1, R3
@@ -105,35 +103,3 @@ again:
done:
MOVD R3, prev+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- CBZ R1, fail
- ADDW $1, R1
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- SUBSW $1, R1, R1
- BEQ fail
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 85163ad62..3b2898256 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-// AndUint32 atomically applies bitwise and operation to *addr with val.
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-// OrUint32 atomically applies bitwise or operation to *addr with val.
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-// AndUint64 atomically applies bitwise and operation to *addr with val.
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-// OrUint64 atomically applies bitwise or operation to *addr with val.
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
-// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
@@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
}
}
}
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v+1) {
- return true
- }
- }
-}
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 1 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v-1) {
- return true
- }
- }
-}
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
index 9466d3e23..73af71bb4 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -196,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) {
}
}
}
-
-func TestIncUnlessZeroInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: 0,
- ret: false,
- },
- {
- initial: 1,
- final: 2,
- ret: true,
- },
- {
- initial: 2,
- final: 3,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := IncUnlessZeroInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
-
-func TestDecUnlessOneInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: -1,
- ret: true,
- },
- {
- initial: 1,
- final: 1,
- ret: false,
- },
- {
- initial: 2,
- final: 1,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := DecUnlessOneInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
index 631785f7b..25065aef9 100644
--- a/pkg/binary/binary.go
+++ b/pkg/binary/binary.go
@@ -254,3 +254,13 @@ func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
_, err := w.Write(buf)
return err
}
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 43a432190..d6cb1a549 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -7,6 +7,8 @@ go_library(
srcs = [
"cpu_amd64.s",
"cpuid.go",
+ "cpuid_arm64.go",
+ "cpuid_x86.go",
],
visibility = ["//:sandbox"],
deps = ["//pkg/log"],
@@ -15,7 +17,10 @@ go_library(
go_test(
name = "cpuid_test",
size = "small",
- srcs = ["cpuid_test.go"],
+ srcs = [
+ "cpuid_arm64_test.go",
+ "cpuid_x86_test.go",
+ ],
library = ":cpuid",
)
@@ -23,7 +28,7 @@ go_test(
name = "cpuid_parse_test",
size = "small",
srcs = [
- "cpuid_parse_test.go",
+ "cpuid_parse_x86_test.go",
],
library = ":cpuid",
tags = ["manual"],
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index cf50ee53f..f7f9dbf86 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
// Package cpuid provides basic functionality for creating and adjusting CPU
// feature sets.
//
@@ -21,1100 +19,20 @@
// known platform, or HostFeatureSet()) and then add, remove, and test for
// features as desired.
//
-// For example: Test for hardware extended state saving, and if we don't have
-// it, don't expose AVX, which cannot be saved with fxsave.
+// For example: on x86, test for hardware extended state saving, and if
+// we don't have it, don't expose AVX, which cannot be saved with fxsave.
//
// if !HostFeatureSet().HasFeature(X86FeatureXSAVE) {
// exposedFeatures.Remove(X86FeatureAVX)
// }
package cpuid
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "strconv"
- "strings"
-
- "gvisor.dev/gvisor/pkg/log"
-)
-
-// Common references for CPUID leaves and bits:
-//
-// Intel:
-// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
-// * Intel Application Note 485 (more detailed)
-//
-// AMD:
-// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
-
// Feature is a unique identifier for a particular cpu feature. We just use an
-// int as a feature number on x86.
+// int as a feature number on x86 and arm64.
//
-// Features are numbered according to "blocks". Each block is 32 bits, and
-// feature bits from the same source (cpuid leaf/level) are in the same block.
-type Feature int
-
-// block is a collection of 32 Feature bits.
-type block int
-
-const blockSize = 32
-
-// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// On x86, features are numbered according to "blocks". Each block is 32 bits, and
// feature bits from the same source (cpuid leaf/level) are in the same block.
-func featureID(b block, bit int) Feature {
- return Feature(32*int(b) + bit)
-}
-
-// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
-// ecx with eax=1.
-const (
- X86FeatureSSE3 Feature = iota
- X86FeaturePCLMULDQ
- X86FeatureDTES64
- X86FeatureMONITOR
- X86FeatureDSCPL
- X86FeatureVMX
- X86FeatureSMX
- X86FeatureEST
- X86FeatureTM2
- X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
- X86FeatureCNXTID
- X86FeatureSDBG
- X86FeatureFMA
- X86FeatureCX16
- X86FeatureXTPR
- X86FeaturePDCM
- _ // ecx bit 16 is reserved.
- X86FeaturePCID
- X86FeatureDCA
- X86FeatureSSE4_1
- X86FeatureSSE4_2
- X86FeatureX2APIC
- X86FeatureMOVBE
- X86FeaturePOPCNT
- X86FeatureTSCD
- X86FeatureAES
- X86FeatureXSAVE
- X86FeatureOSXSAVE
- X86FeatureAVX
- X86FeatureF16C
- X86FeatureRDRAND
- _ // ecx bit 31 is reserved.
-)
-
-// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
-// edx with eax=1.
-const (
- X86FeatureFPU Feature = 32 + iota
- X86FeatureVME
- X86FeatureDE
- X86FeaturePSE
- X86FeatureTSC
- X86FeatureMSR
- X86FeaturePAE
- X86FeatureMCE
- X86FeatureCX8
- X86FeatureAPIC
- _ // edx bit 10 is reserved.
- X86FeatureSEP
- X86FeatureMTRR
- X86FeaturePGE
- X86FeatureMCA
- X86FeatureCMOV
- X86FeaturePAT
- X86FeaturePSE36
- X86FeaturePSN
- X86FeatureCLFSH
- _ // edx bit 20 is reserved.
- X86FeatureDS
- X86FeatureACPI
- X86FeatureMMX
- X86FeatureFXSR
- X86FeatureSSE
- X86FeatureSSE2
- X86FeatureSS
- X86FeatureHTT
- X86FeatureTM
- X86FeatureIA64
- X86FeaturePBE
-)
-
-// Block 2 bits are the "structured extended" features returned in ebx for
-// eax=7, ecx=0.
-const (
- X86FeatureFSGSBase Feature = 2*32 + iota
- X86FeatureTSC_ADJUST
- _ // ebx bit 2 is reserved.
- X86FeatureBMI1
- X86FeatureHLE
- X86FeatureAVX2
- X86FeatureFDP_EXCPTN_ONLY
- X86FeatureSMEP
- X86FeatureBMI2
- X86FeatureERMS
- X86FeatureINVPCID
- X86FeatureRTM
- X86FeatureCQM
- X86FeatureFPCSDS
- X86FeatureMPX
- X86FeatureRDT
- X86FeatureAVX512F
- X86FeatureAVX512DQ
- X86FeatureRDSEED
- X86FeatureADX
- X86FeatureSMAP
- X86FeatureAVX512IFMA
- X86FeaturePCOMMIT
- X86FeatureCLFLUSHOPT
- X86FeatureCLWB
- X86FeatureIPT // Intel processor trace.
- X86FeatureAVX512PF
- X86FeatureAVX512ER
- X86FeatureAVX512CD
- X86FeatureSHA
- X86FeatureAVX512BW
- X86FeatureAVX512VL
-)
-
-// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
-const (
- X86FeaturePREFETCHWT1 Feature = 3*32 + iota
- X86FeatureAVX512VBMI
- X86FeatureUMIP
- X86FeaturePKU
- X86FeatureOSPKE
- X86FeatureWAITPKG
- X86FeatureAVX512_VBMI2
- _ // ecx bit 7 is reserved
- X86FeatureGFNI
- X86FeatureVAES
- X86FeatureVPCLMULQDQ
- X86FeatureAVX512_VNNI
- X86FeatureAVX512_BITALG
- X86FeatureTME
- X86FeatureAVX512_VPOPCNTDQ
- _ // ecx bit 15 is reserved
- X86FeatureLA57
- // ecx bits 17-21 are reserved
- _
- _
- _
- _
- _
- X86FeatureRDPID
- // ecx bits 23-24 are reserved
- _
- _
- X86FeatureCLDEMOTE
- _ // ecx bit 26 is reserved
- X86FeatureMOVDIRI
- X86FeatureMOVDIR64B
-)
-
-// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
-// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
-const (
- X86FeatureXSAVEOPT Feature = 4*32 + iota
- X86FeatureXSAVEC
- X86FeatureXGETBV1
- X86FeatureXSAVES
- // EAX[31:4] are reserved.
-)
-
-// Block 5 constants are the extended feature bits in
-// CPUID.(EAX=0x80000001):ECX.
-const (
- X86FeatureLAHF64 Feature = 5*32 + iota
- X86FeatureCMP_LEGACY
- X86FeatureSVM
- X86FeatureEXTAPIC
- X86FeatureCR8_LEGACY
- X86FeatureLZCNT
- X86FeatureSSE4A
- X86FeatureMISALIGNSSE
- X86FeaturePREFETCHW
- X86FeatureOSVW
- X86FeatureIBS
- X86FeatureXOP
- X86FeatureSKINIT
- X86FeatureWDT
- _ // ecx bit 14 is reserved.
- X86FeatureLWP
- X86FeatureFMA4
- X86FeatureTCE
- _ // ecx bit 18 is reserved.
- _ // ecx bit 19 is reserved.
- _ // ecx bit 20 is reserved.
- X86FeatureTBM
- X86FeatureTOPOLOGY
- X86FeaturePERFCTR_CORE
- X86FeaturePERFCTR_NB
- _ // ecx bit 25 is reserved.
- X86FeatureBPEXT
- X86FeaturePERFCTR_TSC
- X86FeaturePERFCTR_LLC
- X86FeatureMWAITX
- // ECX[31:30] are reserved.
-)
-
-// Block 6 constants are the extended feature bits in
-// CPUID.(EAX=0x80000001):EDX.
//
-// These are sparse, and so the bit positions are assigned manually.
-const (
- // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
- // also defined in block 1 (in identical bit positions). Those features
- // are not listed here.
- block6DuplicateMask = 0x183f3ff
-
- X86FeatureSYSCALL Feature = 6*32 + 11
- X86FeatureNX Feature = 6*32 + 20
- X86FeatureMMXEXT Feature = 6*32 + 22
- X86FeatureFXSR_OPT Feature = 6*32 + 25
- X86FeatureGBPAGES Feature = 6*32 + 26
- X86FeatureRDTSCP Feature = 6*32 + 27
- X86FeatureLM Feature = 6*32 + 29
- X86Feature3DNOWEXT Feature = 6*32 + 30
- X86Feature3DNOW Feature = 6*32 + 31
-)
-
-// linuxBlockOrder defines the order in which linux organizes the feature
-// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
-// which doesn't match well here, so for the /proc/cpuinfo generation we simply
-// re-map the blocks to Linux's ordering and then go through the bits in each
-// block.
-var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
-
-// To make emulation of /proc/cpuinfo easy, these names match the names of the
-// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
-var x86FeatureStrings = map[Feature]string{
- // Block 0.
- X86FeatureSSE3: "pni",
- X86FeaturePCLMULDQ: "pclmulqdq",
- X86FeatureDTES64: "dtes64",
- X86FeatureMONITOR: "monitor",
- X86FeatureDSCPL: "ds_cpl",
- X86FeatureVMX: "vmx",
- X86FeatureSMX: "smx",
- X86FeatureEST: "est",
- X86FeatureTM2: "tm2",
- X86FeatureSSSE3: "ssse3",
- X86FeatureCNXTID: "cid",
- X86FeatureSDBG: "sdbg",
- X86FeatureFMA: "fma",
- X86FeatureCX16: "cx16",
- X86FeatureXTPR: "xtpr",
- X86FeaturePDCM: "pdcm",
- X86FeaturePCID: "pcid",
- X86FeatureDCA: "dca",
- X86FeatureSSE4_1: "sse4_1",
- X86FeatureSSE4_2: "sse4_2",
- X86FeatureX2APIC: "x2apic",
- X86FeatureMOVBE: "movbe",
- X86FeaturePOPCNT: "popcnt",
- X86FeatureTSCD: "tsc_deadline_timer",
- X86FeatureAES: "aes",
- X86FeatureXSAVE: "xsave",
- X86FeatureAVX: "avx",
- X86FeatureF16C: "f16c",
- X86FeatureRDRAND: "rdrand",
-
- // Block 1.
- X86FeatureFPU: "fpu",
- X86FeatureVME: "vme",
- X86FeatureDE: "de",
- X86FeaturePSE: "pse",
- X86FeatureTSC: "tsc",
- X86FeatureMSR: "msr",
- X86FeaturePAE: "pae",
- X86FeatureMCE: "mce",
- X86FeatureCX8: "cx8",
- X86FeatureAPIC: "apic",
- X86FeatureSEP: "sep",
- X86FeatureMTRR: "mtrr",
- X86FeaturePGE: "pge",
- X86FeatureMCA: "mca",
- X86FeatureCMOV: "cmov",
- X86FeaturePAT: "pat",
- X86FeaturePSE36: "pse36",
- X86FeaturePSN: "pn",
- X86FeatureCLFSH: "clflush",
- X86FeatureDS: "dts",
- X86FeatureACPI: "acpi",
- X86FeatureMMX: "mmx",
- X86FeatureFXSR: "fxsr",
- X86FeatureSSE: "sse",
- X86FeatureSSE2: "sse2",
- X86FeatureSS: "ss",
- X86FeatureHTT: "ht",
- X86FeatureTM: "tm",
- X86FeatureIA64: "ia64",
- X86FeaturePBE: "pbe",
-
- // Block 2.
- X86FeatureFSGSBase: "fsgsbase",
- X86FeatureTSC_ADJUST: "tsc_adjust",
- X86FeatureBMI1: "bmi1",
- X86FeatureHLE: "hle",
- X86FeatureAVX2: "avx2",
- X86FeatureSMEP: "smep",
- X86FeatureBMI2: "bmi2",
- X86FeatureERMS: "erms",
- X86FeatureINVPCID: "invpcid",
- X86FeatureRTM: "rtm",
- X86FeatureCQM: "cqm",
- X86FeatureMPX: "mpx",
- X86FeatureRDT: "rdt_a",
- X86FeatureAVX512F: "avx512f",
- X86FeatureAVX512DQ: "avx512dq",
- X86FeatureRDSEED: "rdseed",
- X86FeatureADX: "adx",
- X86FeatureSMAP: "smap",
- X86FeatureCLWB: "clwb",
- X86FeatureAVX512PF: "avx512pf",
- X86FeatureAVX512ER: "avx512er",
- X86FeatureAVX512CD: "avx512cd",
- X86FeatureSHA: "sha_ni",
- X86FeatureAVX512BW: "avx512bw",
- X86FeatureAVX512VL: "avx512vl",
-
- // Block 3.
- X86FeatureAVX512VBMI: "avx512vbmi",
- X86FeatureUMIP: "umip",
- X86FeaturePKU: "pku",
- X86FeatureOSPKE: "ospke",
- X86FeatureWAITPKG: "waitpkg",
- X86FeatureAVX512_VBMI2: "avx512_vbmi2",
- X86FeatureGFNI: "gfni",
- X86FeatureVAES: "vaes",
- X86FeatureVPCLMULQDQ: "vpclmulqdq",
- X86FeatureAVX512_VNNI: "avx512_vnni",
- X86FeatureAVX512_BITALG: "avx512_bitalg",
- X86FeatureTME: "tme",
- X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq",
- X86FeatureLA57: "la57",
- X86FeatureRDPID: "rdpid",
- X86FeatureCLDEMOTE: "cldemote",
- X86FeatureMOVDIRI: "movdiri",
- X86FeatureMOVDIR64B: "movdir64b",
-
- // Block 4.
- X86FeatureXSAVEOPT: "xsaveopt",
- X86FeatureXSAVEC: "xsavec",
- X86FeatureXGETBV1: "xgetbv1",
- X86FeatureXSAVES: "xsaves",
-
- // Block 5.
- X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
- X86FeatureCMP_LEGACY: "cmp_legacy",
- X86FeatureSVM: "svm",
- X86FeatureEXTAPIC: "extapic",
- X86FeatureCR8_LEGACY: "cr8_legacy",
- X86FeatureLZCNT: "abm", // Advanced bit manipulation
- X86FeatureSSE4A: "sse4a",
- X86FeatureMISALIGNSSE: "misalignsse",
- X86FeaturePREFETCHW: "3dnowprefetch",
- X86FeatureOSVW: "osvw",
- X86FeatureIBS: "ibs",
- X86FeatureXOP: "xop",
- X86FeatureSKINIT: "skinit",
- X86FeatureWDT: "wdt",
- X86FeatureLWP: "lwp",
- X86FeatureFMA4: "fma4",
- X86FeatureTCE: "tce",
- X86FeatureTBM: "tbm",
- X86FeatureTOPOLOGY: "topoext",
- X86FeaturePERFCTR_CORE: "perfctr_core",
- X86FeaturePERFCTR_NB: "perfctr_nb",
- X86FeatureBPEXT: "bpext",
- X86FeaturePERFCTR_TSC: "ptsc",
- X86FeaturePERFCTR_LLC: "perfctr_llc",
- X86FeatureMWAITX: "mwaitx",
-
- // Block 6.
- X86FeatureSYSCALL: "syscall",
- X86FeatureNX: "nx",
- X86FeatureMMXEXT: "mmxext",
- X86FeatureFXSR_OPT: "fxsr_opt",
- X86FeatureGBPAGES: "pdpe1gb",
- X86FeatureRDTSCP: "rdtscp",
- X86FeatureLM: "lm",
- X86Feature3DNOWEXT: "3dnowext",
- X86Feature3DNOW: "3dnow",
-}
-
-// These flags are parse only---they can be used for setting / unsetting the
-// flags, but will not get printed out in /proc/cpuinfo.
-var x86FeatureParseOnlyStrings = map[Feature]string{
- // Block 0.
- X86FeatureOSXSAVE: "osxsave",
-
- // Block 2.
- X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
- X86FeatureFPCSDS: "fpcsds",
- X86FeatureIPT: "pt",
- X86FeatureCLFLUSHOPT: "clfushopt",
-
- // Block 3.
- X86FeaturePREFETCHWT1: "prefetchwt1",
-}
-
-// intelCacheDescriptors describe the caches and TLBs on the system. They are
-// returned in the registers for eax=2. Intel only.
-type intelCacheDescriptor uint8
-
-// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
-// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
-const (
- intelNullDescriptor intelCacheDescriptor = 0
- intelNoTLBDescriptor intelCacheDescriptor = 0xfe
- intelNoCacheDescriptor intelCacheDescriptor = 0xff
-
- // Most descriptors omitted for brevity as they are currently unused.
-)
-
-// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
-type CacheType uint8
-
-const (
- // cacheNull indicates that there are no more entries.
- cacheNull CacheType = iota
-
- // CacheData is a data cache.
- CacheData
-
- // CacheInstruction is an instruction cache.
- CacheInstruction
-
- // CacheUnified is a unified instruction and data cache.
- CacheUnified
-)
-
-// Cache describes the parameters of a single cache on the system.
-//
-// +stateify savable
-type Cache struct {
- // Level is the hierarchical level of this cache (L1, L2, etc).
- Level uint32
-
- // Type is the type of cache.
- Type CacheType
-
- // FullyAssociative indicates that entries may be placed in any block.
- FullyAssociative bool
-
- // Partitions is the number of physical partitions in the cache.
- Partitions uint32
-
- // Ways is the number of ways of associativity in the cache.
- Ways uint32
-
- // Sets is the number of sets in the cache.
- Sets uint32
-
- // InvalidateHierarchical indicates that WBINVD/INVD from threads
- // sharing this cache acts upon lower level caches for threads sharing
- // this cache.
- InvalidateHierarchical bool
-
- // Inclusive indicates that this cache is inclusive of lower cache
- // levels.
- Inclusive bool
-
- // DirectMapped indicates that this cache is directly mapped from
- // address, rather than using a hash function.
- DirectMapped bool
-}
-
-// Just a way to wrap cpuid function numbers.
-type cpuidFunction uint32
-
-// The constants below are the lower or "standard" cpuid functions, ordered as
-// defined by the hardware.
-const (
- vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
- featureInfo // Returns basic feature bits and processor signature.
- intelCacheDescriptors // Returns list of cache descriptors. Intel only.
- intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
- intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
- monitorMwaitParams // Returns information about monitor/mwait instructions.
- powerParams // Returns information about power management and thermal sensors.
- extendedFeatureInfo // Returns extended feature bits.
- _ // Function 0x8 is reserved.
- intelDCAParams // Returns direct cache access information. Intel only.
- intelPMCInfo // Returns information about performance monitoring features. Intel only.
- intelX2APICInfo // Returns core/logical processor topology. Intel only.
- _ // Function 0xc is reserved.
- xSaveInfo // Returns information about extended state management.
-)
-
-// The "extended" functions start at 0x80000000.
-const (
- extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
- extendedFeatures // Returns some extended feature bits in edx and ecx.
-)
-
-// These are the extended floating point state features. They are used to
-// enumerate floating point features in XCR0, XSTATE_BV, etc.
-const (
- XSAVEFeatureX87 = 1 << 0
- XSAVEFeatureSSE = 1 << 1
- XSAVEFeatureAVX = 1 << 2
- XSAVEFeatureBNDREGS = 1 << 3
- XSAVEFeatureBNDCSR = 1 << 4
- XSAVEFeatureAVX512op = 1 << 5
- XSAVEFeatureAVX512zmm0 = 1 << 6
- XSAVEFeatureAVX512zmm16 = 1 << 7
- XSAVEFeaturePKRU = 1 << 9
-)
-
-var cpuFreqMHz float64
-
-// x86FeaturesFromString includes features from x86FeatureStrings and
-// x86FeatureParseOnlyStrings.
-var x86FeaturesFromString = make(map[string]Feature)
-
-// FeatureFromString returns the Feature associated with the given feature
-// string plus a bool to indicate if it could find the feature.
-func FeatureFromString(s string) (Feature, bool) {
- f, b := x86FeaturesFromString[s]
- return f, b
-}
-
-// String implements fmt.Stringer.
-func (f Feature) String() string {
- if s := f.flagString(false); s != "" {
- return s
- }
-
- block := int(f) / 32
- bit := int(f) % 32
- return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
-}
-
-func (f Feature) flagString(cpuinfoOnly bool) string {
- if s, ok := x86FeatureStrings[f]; ok {
- return s
- }
- if !cpuinfoOnly {
- return x86FeatureParseOnlyStrings[f]
- }
- return ""
-}
-
-// FeatureSet is a set of Features for a CPU.
-//
-// +stateify savable
-type FeatureSet struct {
- // Set is the set of features that are enabled in this FeatureSet.
- Set map[Feature]bool
-
- // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
- VendorID string
-
- // ExtendedFamily is part of the processor signature.
- ExtendedFamily uint8
-
- // ExtendedModel is part of the processor signature.
- ExtendedModel uint8
-
- // ProcessorType is part of the processor signature.
- ProcessorType uint8
-
- // Family is part of the processor signature.
- Family uint8
-
- // Model is part of the processor signature.
- Model uint8
-
- // SteppingID is part of the processor signature.
- SteppingID uint8
-
- // Caches describes the caches on the CPU.
- Caches []Cache
-
- // CacheLine is the size of a cache line in bytes.
- //
- // All caches use the same line size. This is not enforced in the CPUID
- // encoding, but is true on all known x86 processors.
- CacheLine uint32
-}
-
-// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
-// equivalent to the "flags" field in /proc/cpuinfo.
-func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
- var s []string
- for _, b := range linuxBlockOrder {
- for i := 0; i < blockSize; i++ {
- if f := featureID(b, i); fs.Set[f] {
- if fstr := f.flagString(cpuinfoOnly); fstr != "" {
- s = append(s, fstr)
- }
- }
- }
- }
- return strings.Join(s, " ")
-}
-
-// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
-// a minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
-// not always printed in Linux. The bogomips field is simply made up.
-func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
- fmt.Fprintf(b, "processor\t: %d\n", cpu)
- fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
- fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
- fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
- fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
- fmt.Fprintln(b, "fpu\t\t: yes")
- fmt.Fprintln(b, "fpu_exception\t: yes")
- fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
- fmt.Fprintln(b, "wp\t\t: yes")
- fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
- fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
- fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
- fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
- fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
- fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
-}
-
-const (
- amdVendorID = "AuthenticAMD"
- intelVendorID = "GenuineIntel"
-)
-
-// AMD returns true if fs describes an AMD CPU.
-func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == amdVendorID
-}
-
-// Intel returns true if fs describes an Intel CPU.
-func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == intelVendorID
-}
-
-// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
-// subset of the host feature set.
-type ErrIncompatible struct {
- message string
-}
-
-// Error implements error.
-func (e ErrIncompatible) Error() string {
- return e.message
-}
-
-// CheckHostCompatible returns nil if fs is a subset of the host feature set.
-func (fs *FeatureSet) CheckHostCompatible() error {
- hfs := HostFeatureSet()
-
- if diff := fs.Subtract(hfs); diff != nil {
- return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
- }
-
- // The size of a cache line must match, as it is critical to correctly
- // utilizing CLFLUSH. Other cache properties are allowed to change, as
- // they are not important to correctness.
- if fs.CacheLine != hfs.CacheLine {
- return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
- }
-
- return nil
-}
-
-// Helper to convert 3 regs into 12-byte vendor ID.
-func vendorIDFromRegs(bx, cx, dx uint32) string {
- bytes := make([]byte, 0, 12)
- for i := uint(0); i < 4; i++ {
- b := byte(bx >> (i * 8))
- bytes = append(bytes, b)
- }
-
- for i := uint(0); i < 4; i++ {
- b := byte(dx >> (i * 8))
- bytes = append(bytes, b)
- }
-
- for i := uint(0); i < 4; i++ {
- b := byte(cx >> (i * 8))
- bytes = append(bytes, b)
- }
- return string(bytes)
-}
-
-// ExtendedStateSize returns the number of bytes needed to save the "extended
-// state" for this processor and the boundary it must be aligned to. Extended
-// state includes floating point registers, and other cpu state that's not
-// associated with the normal task context.
-//
-// Note: We can save some space here with an optimization where we use a
-// smaller chunk of memory depending on features that are actually enabled.
-// Currently we just use the largest possible size for simplicity (which is
-// about 2.5K worst case, with avx512).
-func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
- if fs.UseXsave() {
- // Leaf 0 of xsaveinfo function returns the size for currently
- // enabled xsave features in ebx, the maximum size if all valid
- // features are saved with xsave in ecx, and valid XCR0 bits in
- // edx:eax.
- _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
- return uint(maxSize), 64
- }
-
- // If we don't support xsave, we fall back to fxsave, which requires
- // 512 bytes aligned to 16 bytes.
- return 512, 16
-}
-
-// ValidXCR0Mask returns the bits that may be set to 1 in control register
-// XCR0.
-func (fs *FeatureSet) ValidXCR0Mask() uint64 {
- if !fs.UseXsave() {
- return 0
- }
- eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
- return uint64(edx)<<32 | uint64(eax)
-}
-
-// vendorIDRegs returns the 3 register values used to construct the 12-byte
-// vendor ID string for eax=0.
-func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
- for i := uint(0); i < 4; i++ {
- bx |= uint32(fs.VendorID[i]) << (i * 8)
- }
-
- for i := uint(0); i < 4; i++ {
- dx |= uint32(fs.VendorID[i+4]) << (i * 8)
- }
-
- for i := uint(0); i < 4; i++ {
- cx |= uint32(fs.VendorID[i+8]) << (i * 8)
- }
- return
-}
-
-// signature returns the signature dword that's returned in eax when eax=1.
-func (fs *FeatureSet) signature() uint32 {
- var s uint32
- s |= uint32(fs.SteppingID & 0xf)
- s |= uint32(fs.Model&0xf) << 4
- s |= uint32(fs.Family&0xf) << 8
- s |= uint32(fs.ProcessorType&0x3) << 12
- s |= uint32(fs.ExtendedModel&0xf) << 16
- s |= uint32(fs.ExtendedFamily&0xff) << 20
- return s
-}
-
-// Helper to deconstruct signature dword.
-func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
- sid = uint8(v & 0xf)
- m = uint8(v>>4) & 0xf
- f = uint8(v>>8) & 0xf
- pt = uint8(v>>12) & 0x3
- em = uint8(v>>16) & 0xf
- ef = uint8(v >> 20)
- return
-}
-
-// Helper to convert blockwise feature bit masks into a set of features. Masks
-// must be provided in order for each block, without skipping them. If a block
-// does not matter for this feature set, 0 is specified.
-func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
- s := make(map[Feature]bool)
- for b, blockMask := range blocks {
- for i := 0; i < blockSize; i++ {
- if blockMask&1 != 0 {
- s[featureID(block(b), i)] = true
- }
- blockMask >>= 1
- }
- }
- return s
-}
-
-// blockMask returns the 32-bit mask associated with a block of features.
-func (fs *FeatureSet) blockMask(b block) uint32 {
- var mask uint32
- for i := 0; i < blockSize; i++ {
- if fs.Set[featureID(b, i)] {
- mask |= 1 << uint(i)
- }
- }
- return mask
-}
-
-// Remove removes a Feature from a FeatureSet. It ignores features
-// that are not in the FeatureSet.
-func (fs *FeatureSet) Remove(feature Feature) {
- delete(fs.Set, feature)
-}
-
-// Add adds a Feature to a FeatureSet. It ignores duplicate features.
-func (fs *FeatureSet) Add(feature Feature) {
- fs.Set[feature] = true
-}
-
-// HasFeature tests whether or not a feature is in the given feature set.
-func (fs *FeatureSet) HasFeature(feature Feature) bool {
- return fs.Set[feature]
-}
-
-// Subtract returns the features present in fs that are not present in other.
-// If all features in fs are present in other, Subtract returns nil.
-func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
- for f := range fs.Set {
- if !other.Set[f] {
- if diff == nil {
- diff = make(map[Feature]bool)
- }
- diff[f] = true
- }
- }
-
- return
-}
-
-// EmulateID emulates a cpuid instruction based on the feature set.
-func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
- switch cpuidFunction(origAx) {
- case vendorID:
- ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
- bx, dx, cx = fs.vendorIDRegs()
- case featureInfo:
- // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
- bx = (fs.CacheLine / 8) << 8
- cx = fs.blockMask(block(0))
- dx = fs.blockMask(block(1))
- ax = fs.signature()
- case intelCacheDescriptors:
- if !fs.Intel() {
- // Reserved on non-Intel.
- return 0, 0, 0, 0
- }
-
- // "The least-significant byte in register EAX (register AL)
- // will always return 01H. Software should ignore this value
- // and not interpret it as an informational descriptor." - SDM
- //
- // We only support reporting cache parameters via
- // intelDeterministicCacheParams; report as much here.
- //
- // We do not support exposing TLB information at all.
- ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
- case intelDeterministicCacheParams:
- if !fs.Intel() {
- // Reserved on non-Intel.
- return 0, 0, 0, 0
- }
-
- // cx is the index of the cache to describe.
- if int(origCx) >= len(fs.Caches) {
- return uint32(cacheNull), 0, 0, 0
- }
- c := fs.Caches[origCx]
-
- ax = uint32(c.Type)
- ax |= c.Level << 5
- ax |= 1 << 8 // Always claim the cache is "self-initializing".
- if c.FullyAssociative {
- ax |= 1 << 9
- }
- // Processor topology not supported.
-
- bx = fs.CacheLine - 1
- bx |= (c.Partitions - 1) << 12
- bx |= (c.Ways - 1) << 22
-
- cx = c.Sets - 1
-
- if !c.InvalidateHierarchical {
- dx |= 1
- }
- if c.Inclusive {
- dx |= 1 << 1
- }
- if !c.DirectMapped {
- dx |= 1 << 2
- }
- case xSaveInfo:
- if !fs.UseXsave() {
- return 0, 0, 0, 0
- }
- return HostID(uint32(xSaveInfo), origCx)
- case extendedFeatureInfo:
- if origCx != 0 {
- break // Only leaf 0 is supported.
- }
- bx = fs.blockMask(block(2))
- cx = fs.blockMask(block(3))
- case extendedFunctionInfo:
- // We only support showing the extended features.
- ax = uint32(extendedFeatures)
- cx = 0
- case extendedFeatures:
- cx = fs.blockMask(block(5))
- dx = fs.blockMask(block(6))
- if fs.AMD() {
- // AMD duplicates some block 1 features in block 6.
- dx |= fs.blockMask(block(1)) & block6DuplicateMask
- }
- }
-
- return
-}
-
-// UseXsave returns the choice of fp state saving instruction.
-func (fs *FeatureSet) UseXsave() bool {
- return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
-}
-
-// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
-func (fs *FeatureSet) UseXsaveopt() bool {
- return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
-}
-
-// HostID executes a native CPUID instruction.
-func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
-
-// HostFeatureSet uses cpuid to get host values and construct a feature set
-// that matches that of the host machine. Note that there are several places
-// where there appear to be some unnecessary assignments between register names
-// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
-// where the different feature blocks come from, to make the code easier to
-// inspect and read.
-func HostFeatureSet() *FeatureSet {
- // eax=0 gets max supported feature and vendor ID.
- _, bx, cx, dx := HostID(0, 0)
- vendorID := vendorIDFromRegs(bx, cx, dx)
-
- // eax=1 gets basic features in ecx:edx.
- ax, bx, cx, dx := HostID(1, 0)
- featureBlock0 := cx
- featureBlock1 := dx
- ef, em, pt, f, m, sid := signatureSplit(ax)
- cacheLine := 8 * (bx >> 8) & 0xff
-
- // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
- var caches []Cache
- if vendorID == intelVendorID {
- // ecx selects the cache index until a null type is returned.
- for i := uint32(0); ; i++ {
- ax, bx, cx, dx := HostID(4, i)
- t := CacheType(ax & 0xf)
- if t == cacheNull {
- break
- }
-
- lineSize := (bx & 0xfff) + 1
- if lineSize != cacheLine {
- panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
- }
-
- caches = append(caches, Cache{
- Type: t,
- Level: (ax >> 5) & 0x7,
- FullyAssociative: ((ax >> 9) & 1) == 1,
- Partitions: ((bx >> 12) & 0x3ff) + 1,
- Ways: ((bx >> 22) & 0x3ff) + 1,
- Sets: cx + 1,
- InvalidateHierarchical: (dx & 1) == 0,
- Inclusive: ((dx >> 1) & 1) == 1,
- DirectMapped: ((dx >> 2) & 1) == 0,
- })
- }
- }
-
- // eax=7, ecx=0 gets extended features in ecx:ebx.
- _, bx, cx, _ = HostID(7, 0)
- featureBlock2 := bx
- featureBlock3 := cx
-
- // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
- var featureBlock4 uint32
- if (featureBlock0 & (1 << 26)) != 0 {
- featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
- }
-
- // eax=0x80000000 gets supported extended levels. We use this to
- // determine if there are any non-zero block 4 or block 6 bits to find.
- var featureBlock5, featureBlock6 uint32
- if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
- // eax=0x80000001 gets AMD added feature bits.
- _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
- featureBlock5 = cx
- // Ignore features duplicated from block 1 on AMD. These bits
- // are reserved on Intel.
- featureBlock6 = dx &^ block6DuplicateMask
- }
-
- set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
- return &FeatureSet{
- Set: set,
- VendorID: vendorID,
- ExtendedFamily: ef,
- ExtendedModel: em,
- ProcessorType: pt,
- Family: f,
- Model: m,
- SteppingID: sid,
- CacheLine: cacheLine,
- Caches: caches,
- }
-}
-
-// Reads max cpu frequency from host /proc/cpuinfo. Must run before
-// whitelisting. This value is used to create the fake /proc/cpuinfo from a
-// FeatureSet.
-func initCPUFreq() {
- cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
- if err != nil {
- // Leave it as 0... The standalone VDSO bails out in the same
- // way.
- log.Warningf("Could not read /proc/cpuinfo: %v", err)
- return
- }
- cpuinfo := string(cpuinfob)
-
- // We get the value straight from host /proc/cpuinfo. On machines with
- // frequency scaling enabled, this will only get the current value
- // which will likely be inaccurate. This is fine on machines with
- // frequency scaling disabled.
- for _, line := range strings.Split(cpuinfo, "\n") {
- if strings.Contains(line, "cpu MHz") {
- splitMHz := strings.Split(line, ":")
- if len(splitMHz) < 2 {
- log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
- return
- }
-
- // If there was a problem, leave cpuFreqMHz as 0.
- var err error
- cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
- if err != nil {
- log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
- cpuFreqMHz = 0
- return
- }
- return
- }
- }
- log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
-}
-
-func initFeaturesFromString() {
- for f, s := range x86FeatureStrings {
- x86FeaturesFromString[s] = f
- }
- for f, s := range x86FeatureParseOnlyStrings {
- x86FeaturesFromString[s] = f
- }
-}
-
-func init() {
- // initCpuFreq must be run before whitelists are enabled.
- initCPUFreq()
- initFeaturesFromString()
-}
+// On arm64, features are numbered according to the ELF HWCAP definition.
+// arch/arm64/include/uapi/asm/hwcap.h
+type Feature int
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
new file mode 100644
index 000000000..08381c1c0
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -0,0 +1,482 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ARM64 doesn't have a 'cpuid' equivalent, which means it have no architected
+// discovery mechanism for hardware features available to userspace code at EL0.
+// The kernel exposes the presence of these features to userspace through a set
+// of flags(HWCAP/HWCAP2) bits, exposed in the auxilliary vector.
+// Ref Documentation/arm64/elf_hwcaps.rst for more info.
+//
+// Currently, only the HWCAP bits are supported.
+
+const (
+ // ARM64FeatureFP indicates support for single and double precision
+ // float point types.
+ ARM64FeatureFP Feature = iota
+
+ // ARM64FeatureASIMD indicates support for Advanced SIMD with single
+ // and double precision float point arithmetic.
+ ARM64FeatureASIMD
+
+ // ARM64FeatureEVTSTRM indicates support for the generic timer
+ // configured to generate events at a frequency of approximately
+ // 100KHz.
+ ARM64FeatureEVTSTRM
+
+ // ARM64FeatureAES indicates support for AES instructions
+ // (AESE/AESD/AESMC/AESIMC).
+ ARM64FeatureAES
+
+ // ARM64FeaturePMULL indicates support for AES instructions
+ // (PMULL/PMULL2).
+ ARM64FeaturePMULL
+
+ // ARM64FeatureSHA1 indicates support for SHA1 instructions
+ // (SHA1C/SHA1P/SHA1M etc).
+ ARM64FeatureSHA1
+
+ // ARM64FeatureSHA2 indicates support for SHA2 instructions
+ // (SHA256H/SHA256H2/SHA256SU0 etc).
+ ARM64FeatureSHA2
+
+ // ARM64FeatureCRC32 indicates support for CRC32 instructions
+ // (CRC32B/CRC32H/CRC32W etc).
+ ARM64FeatureCRC32
+
+ // ARM64FeatureATOMICS indicates support for atomic instructions
+ // (LDADD/LDCLR/LDEOR/LDSET etc).
+ ARM64FeatureATOMICS
+
+ // ARM64FeatureFPHP indicates support for half precision float point
+ // arithmetic.
+ ARM64FeatureFPHP
+
+ // ARM64FeatureASIMDHP indicates support for ASIMD with half precision
+ // float point arithmetic.
+ ARM64FeatureASIMDHP
+
+ // ARM64FeatureCPUID indicates support for EL0 access to certain ID
+ // registers is available.
+ ARM64FeatureCPUID
+
+ // ARM64FeatureASIMDRDM indicates support for SQRDMLAH and SQRDMLSH
+ // instructions.
+ ARM64FeatureASIMDRDM
+
+ // ARM64FeatureJSCVT indicates support for the FJCVTZS instruction.
+ ARM64FeatureJSCVT
+
+ // ARM64FeatureFCMA indicates support for the FCMLA and FCADD
+ // instructions.
+ ARM64FeatureFCMA
+
+ // ARM64FeatureLRCPC indicates support for the LDAPRB/LDAPRH/LDAPR
+ // instructions.
+ ARM64FeatureLRCPC
+
+ // ARM64FeatureDCPOP indicates support for DC instruction (DC CVAP).
+ ARM64FeatureDCPOP
+
+ // ARM64FeatureSHA3 indicates support for SHA3 instructions
+ // (EOR3/RAX1/XAR/BCAX).
+ ARM64FeatureSHA3
+
+ // ARM64FeatureSM3 indicates support for SM3 instructions
+ // (SM3SS1/SM3TT1A/SM3TT1B).
+ ARM64FeatureSM3
+
+ // ARM64FeatureSM4 indicates support for SM4 instructions
+ // (SM4E/SM4EKEY).
+ ARM64FeatureSM4
+
+ // ARM64FeatureASIMDDP indicates support for dot product instructions
+ // (UDOT/SDOT).
+ ARM64FeatureASIMDDP
+
+ // ARM64FeatureSHA512 indicates support for SHA2 instructions
+ // (SHA512H/SHA512H2/SHA512SU0).
+ ARM64FeatureSHA512
+
+ // ARM64FeatureSVE indicates support for Scalable Vector Extension.
+ ARM64FeatureSVE
+
+ // ARM64FeatureASIMDFHM indicates support for FMLAL and FMLSL
+ // instructions.
+ ARM64FeatureASIMDFHM
+)
+
+// ELF auxiliary vector tags
+const (
+ _AT_NULL = 0 // End of vector
+ _AT_HWCAP = 16 // hardware capability bit vector
+ _AT_HWCAP2 = 26 // hardware capability bit vector 2
+)
+
+// These should not be changed after they are initialized.
+var hwCap uint
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/arm64/kernel/cpuinfo.c.
+var arm64FeatureStrings = map[Feature]string{
+ ARM64FeatureFP: "fp",
+ ARM64FeatureASIMD: "asimd",
+ ARM64FeatureEVTSTRM: "evtstrm",
+ ARM64FeatureAES: "aes",
+ ARM64FeaturePMULL: "pmull",
+ ARM64FeatureSHA1: "sha1",
+ ARM64FeatureSHA2: "sha2",
+ ARM64FeatureCRC32: "crc32",
+ ARM64FeatureATOMICS: "atomics",
+ ARM64FeatureFPHP: "fphp",
+ ARM64FeatureASIMDHP: "asimdhp",
+ ARM64FeatureCPUID: "cpuid",
+ ARM64FeatureASIMDRDM: "asimdrdm",
+ ARM64FeatureJSCVT: "jscvt",
+ ARM64FeatureFCMA: "fcma",
+ ARM64FeatureLRCPC: "lrcpc",
+ ARM64FeatureDCPOP: "dcpop",
+ ARM64FeatureSHA3: "sha3",
+ ARM64FeatureSM3: "sm3",
+ ARM64FeatureSM4: "sm4",
+ ARM64FeatureASIMDDP: "asimddp",
+ ARM64FeatureSHA512: "sha512",
+ ARM64FeatureSVE: "sve",
+ ARM64FeatureASIMDFHM: "asimdfhm",
+}
+
+var (
+ cpuFreqMHz float64
+ cpuImplHex uint64
+ cpuArchDec uint64
+ cpuVarHex uint64
+ cpuPartHex uint64
+ cpuRevDec uint64
+)
+
+// arm64FeaturesFromString includes features from arm64FeatureStrings.
+var arm64FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string plus a bool to indicate if it could find the feature.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := arm64FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(); s != "" {
+ return s
+ }
+
+ return fmt.Sprintf("<cpuflag %d>", f)
+}
+
+func (f Feature) flagString() string {
+ if s, ok := arm64FeatureStrings[f]; ok {
+ return s
+ }
+
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // CPUImplementer is part of the processor signature.
+ CPUImplementer uint8
+
+ // CPUArchitecture is part of the processor signature.
+ CPUArchitecture uint8
+
+ // CPUVariant is part of the processor signature.
+ CPUVariant uint8
+
+ // CPUPartnum is part of the processor signature.
+ CPUPartnum uint16
+
+ // CPURevision is part of the processor signature.
+ CPURevision uint8
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+// Noop on arm64.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ return nil
+}
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point(NEON) registers, and other cpu state that's not
+// associated with the normal task context.
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ // ARMv8 provide 32x128bits NEON registers.
+ //
+ // Ref arch/arm64/include/uapi/asm/ptrace.h
+ // struct user_fpsimd_state {
+ // __uint128_t vregs[32];
+ // __u32 fpsr;
+ // __u32 fpcr;
+ // __u32 __reserved[2];
+ // };
+ return 528, 16
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// UseXsave returns true if 'fs' supports the "xsave" instruction.
+//
+// Irrelevant on arm64.
+func (fs *FeatureSet) UseXsave() bool {
+ return false
+}
+
+// FlagsString prints out supported CPU "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString() string {
+ var s []string
+ for f, _ := range arm64FeatureStrings {
+ if fs.Set[f] {
+ if fstr := f.flagString(); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo, and the bogomips field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "BogoMIPS\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "Features\t\t: %s\n", fs.FlagsString())
+ fmt.Fprintf(b, "CPU implementer\t: 0x%x\n", cpuImplHex)
+ fmt.Fprintf(b, "CPU architecture\t: %d\n", cpuArchDec)
+ fmt.Fprintf(b, "CPU variant\t: 0x%x\n", cpuVarHex)
+ fmt.Fprintf(b, "CPU part\t: 0x%x\n", cpuPartHex)
+ fmt.Fprintf(b, "CPU revision\t: %d\n", cpuRevDec)
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+// HostFeatureSet uses hwCap to get host values and construct a feature set
+// that matches that of the host machine.
+func HostFeatureSet() *FeatureSet {
+ s := make(map[Feature]bool)
+
+ for f, _ := range arm64FeatureStrings {
+ if hwCap&(1<<f) != 0 {
+ s[f] = true
+ }
+ }
+
+ return &FeatureSet{
+ Set: s,
+ CPUImplementer: uint8(cpuImplHex),
+ CPUArchitecture: uint8(cpuArchDec),
+ CPUVariant: uint8(cpuVarHex),
+ CPUPartnum: uint16(cpuPartHex),
+ CPURevision: uint8(cpuRevDec),
+ }
+}
+
+// Reads bogomips from host /proc/cpuinfo. Must run before whitelisting.
+// This value is used to create the fake /proc/cpuinfo from a FeatureSet.
+func initCPUInfo() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0. The standalone VDSO bails out in the same way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ switch {
+ case strings.Contains(line, "BogoMIPS"):
+ {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed BogoMIPS")
+ break
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse BogoMIPS value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ }
+ }
+ case strings.Contains(line, "CPU implementer"):
+ {
+ splitImpl := strings.Split(line, ":")
+ if len(splitImpl) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU implementer")
+ break
+ }
+
+ // If there was a problem, leave cpuImplHex as 0.
+ var err error
+ cpuImplHex, err = strconv.ParseUint(strings.TrimSpace(splitImpl[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU implementer value %v: %v", splitImpl[1], err)
+ cpuImplHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU architecture"):
+ {
+ splitArch := strings.Split(line, ":")
+ if len(splitArch) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU architecture")
+ break
+ }
+
+ // If there was a problem, leave cpuArchDec as 0.
+ var err error
+ cpuArchDec, err = strconv.ParseUint(strings.TrimSpace(splitArch[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU architecture value %v: %v", splitArch[1], err)
+ cpuArchDec = 0
+ }
+ }
+ case strings.Contains(line, "CPU variant"):
+ {
+ splitVar := strings.Split(line, ":")
+ if len(splitVar) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU variant")
+ break
+ }
+
+ // If there was a problem, leave cpuVarHex as 0.
+ var err error
+ cpuVarHex, err = strconv.ParseUint(strings.TrimSpace(splitVar[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU variant value %v: %v", splitVar[1], err)
+ cpuVarHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU part"):
+ {
+ splitPart := strings.Split(line, ":")
+ if len(splitPart) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU part")
+ break
+ }
+
+ // If there was a problem, leave cpuPartHex as 0.
+ var err error
+ cpuPartHex, err = strconv.ParseUint(strings.TrimSpace(splitPart[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU part value %v: %v", splitPart[1], err)
+ cpuPartHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU revision"):
+ {
+ splitRev := strings.Split(line, ":")
+ if len(splitRev) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU revision")
+ break
+ }
+
+ // If there was a problem, leave cpuRevDec as 0.
+ var err error
+ cpuRevDec, err = strconv.ParseUint(strings.TrimSpace(splitRev[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU revision value %v: %v", splitRev[1], err)
+ cpuRevDec = 0
+ }
+ }
+ }
+ }
+}
+
+// The auxiliary vector of a process on the Linux system can be read
+// from /proc/self/auxv, and tags and values are stored as 8-bytes
+// decimal key-value pairs on the 64-bit system.
+//
+// $ od -t d8 /proc/self/auxv
+// 0000000 33 140734615224320
+// 0000020 16 3219913727
+// 0000040 6 4096
+// 0000060 17 100
+// 0000100 3 94665627353152
+// 0000120 4 56
+// 0000140 5 9
+// 0000160 7 140425502162944
+// 0000200 8 0
+// 0000220 9 94665627365760
+// 0000240 11 1000
+// 0000260 12 1000
+// 0000300 13 1000
+// 0000320 14 1000
+// 0000340 23 0
+// 0000360 25 140734614619513
+// 0000400 26 0
+// 0000420 31 140734614626284
+// 0000440 15 140734614619529
+// 0000460 0 0
+func initHwCap() {
+ auxv, err := ioutil.ReadFile("/proc/self/auxv")
+ if err != nil {
+ log.Warningf("Could not read /proc/self/auxv: %v", err)
+ return
+ }
+
+ l := len(auxv) / 16
+ for i := 0; i < l; i++ {
+ tag := binary.LittleEndian.Uint64(auxv[i*16:])
+ val := binary.LittleEndian.Uint64(auxv[(i*16 + 8):])
+ if tag == _AT_HWCAP {
+ hwCap = uint(val)
+ break
+ }
+ }
+}
+
+func initFeaturesFromString() {
+ for f, s := range arm64FeatureStrings {
+ arm64FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ initCPUInfo()
+ initHwCap()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_arm64_test.go b/pkg/cpuid/cpuid_arm64_test.go
new file mode 100644
index 000000000..a34f67779
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "testing"
+)
+
+var justFP = &FeatureSet{
+ Set: map[Feature]bool{
+ ARM64FeatureFP: true,
+ }}
+
+func TestHostFeatureSet(t *testing.T) {
+ hostFeatures := HostFeatureSet()
+ if len(hostFeatures.Set) == 0 {
+ t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
+ }
+}
+
+func TestHasFeature(t *testing.T) {
+ if !justFP.HasFeature(ARM64FeatureFP) {
+ t.Errorf("HasFeature failed, %v should contain %v", justFP, ARM64FeatureFP)
+ }
+
+ if justFP.HasFeature(ARM64FeatureSM3) {
+ t.Errorf("HasFeature failed, %v should not contain %v", justFP, ARM64FeatureSM3)
+ }
+}
+
+func TestFeatureFromString(t *testing.T) {
+ f, ok := FeatureFromString("asimd")
+ if f != ARM64FeatureASIMD || !ok {
+ t.Errorf("got %v want asimd", f)
+ }
+
+ f, ok = FeatureFromString("bad")
+ if ok {
+ t.Errorf("got %v want nothing", f)
+ }
+}
diff --git a/pkg/cpuid/cpuid_parse_test.go b/pkg/cpuid/cpuid_parse_x86_test.go
index dd9969db4..d48418e69 100644
--- a/pkg/cpuid/cpuid_parse_test.go
+++ b/pkg/cpuid/cpuid_parse_x86_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build i386 amd64
+
package cpuid
import (
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
new file mode 100644
index 000000000..a0bc55ea1
--- /dev/null
+++ b/pkg/cpuid/cpuid_x86.go
@@ -0,0 +1,1107 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build i386 amd64
+
+package cpuid
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// Common references for CPUID leaves and bits:
+//
+// Intel:
+// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
+// * Intel Application Note 485 (more detailed)
+//
+// AMD:
+// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
+
+// block is a collection of 32 Feature bits.
+type block int
+
+const blockSize = 32
+
+// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
+func featureID(b block, bit int) Feature {
+ return Feature(32*int(b) + bit)
+}
+
+// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
+// ecx with eax=1.
+const (
+ X86FeatureSSE3 Feature = iota
+ X86FeaturePCLMULDQ
+ X86FeatureDTES64
+ X86FeatureMONITOR
+ X86FeatureDSCPL
+ X86FeatureVMX
+ X86FeatureSMX
+ X86FeatureEST
+ X86FeatureTM2
+ X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
+ X86FeatureCNXTID
+ X86FeatureSDBG
+ X86FeatureFMA
+ X86FeatureCX16
+ X86FeatureXTPR
+ X86FeaturePDCM
+ _ // ecx bit 16 is reserved.
+ X86FeaturePCID
+ X86FeatureDCA
+ X86FeatureSSE4_1
+ X86FeatureSSE4_2
+ X86FeatureX2APIC
+ X86FeatureMOVBE
+ X86FeaturePOPCNT
+ X86FeatureTSCD
+ X86FeatureAES
+ X86FeatureXSAVE
+ X86FeatureOSXSAVE
+ X86FeatureAVX
+ X86FeatureF16C
+ X86FeatureRDRAND
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
+// edx with eax=1.
+const (
+ X86FeatureFPU Feature = 32 + iota
+ X86FeatureVME
+ X86FeatureDE
+ X86FeaturePSE
+ X86FeatureTSC
+ X86FeatureMSR
+ X86FeaturePAE
+ X86FeatureMCE
+ X86FeatureCX8
+ X86FeatureAPIC
+ _ // edx bit 10 is reserved.
+ X86FeatureSEP
+ X86FeatureMTRR
+ X86FeaturePGE
+ X86FeatureMCA
+ X86FeatureCMOV
+ X86FeaturePAT
+ X86FeaturePSE36
+ X86FeaturePSN
+ X86FeatureCLFSH
+ _ // edx bit 20 is reserved.
+ X86FeatureDS
+ X86FeatureACPI
+ X86FeatureMMX
+ X86FeatureFXSR
+ X86FeatureSSE
+ X86FeatureSSE2
+ X86FeatureSS
+ X86FeatureHTT
+ X86FeatureTM
+ X86FeatureIA64
+ X86FeaturePBE
+)
+
+// Block 2 bits are the "structured extended" features returned in ebx for
+// eax=7, ecx=0.
+const (
+ X86FeatureFSGSBase Feature = 2*32 + iota
+ X86FeatureTSC_ADJUST
+ _ // ebx bit 2 is reserved.
+ X86FeatureBMI1
+ X86FeatureHLE
+ X86FeatureAVX2
+ X86FeatureFDP_EXCPTN_ONLY
+ X86FeatureSMEP
+ X86FeatureBMI2
+ X86FeatureERMS
+ X86FeatureINVPCID
+ X86FeatureRTM
+ X86FeatureCQM
+ X86FeatureFPCSDS
+ X86FeatureMPX
+ X86FeatureRDT
+ X86FeatureAVX512F
+ X86FeatureAVX512DQ
+ X86FeatureRDSEED
+ X86FeatureADX
+ X86FeatureSMAP
+ X86FeatureAVX512IFMA
+ X86FeaturePCOMMIT
+ X86FeatureCLFLUSHOPT
+ X86FeatureCLWB
+ X86FeatureIPT // Intel processor trace.
+ X86FeatureAVX512PF
+ X86FeatureAVX512ER
+ X86FeatureAVX512CD
+ X86FeatureSHA
+ X86FeatureAVX512BW
+ X86FeatureAVX512VL
+)
+
+// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
+const (
+ X86FeaturePREFETCHWT1 Feature = 3*32 + iota
+ X86FeatureAVX512VBMI
+ X86FeatureUMIP
+ X86FeaturePKU
+ X86FeatureOSPKE
+ X86FeatureWAITPKG
+ X86FeatureAVX512_VBMI2
+ _ // ecx bit 7 is reserved
+ X86FeatureGFNI
+ X86FeatureVAES
+ X86FeatureVPCLMULQDQ
+ X86FeatureAVX512_VNNI
+ X86FeatureAVX512_BITALG
+ X86FeatureTME
+ X86FeatureAVX512_VPOPCNTDQ
+ _ // ecx bit 15 is reserved
+ X86FeatureLA57
+ // ecx bits 17-21 are reserved
+ _
+ _
+ _
+ _
+ _
+ X86FeatureRDPID
+ // ecx bits 23-24 are reserved
+ _
+ _
+ X86FeatureCLDEMOTE
+ _ // ecx bit 26 is reserved
+ X86FeatureMOVDIRI
+ X86FeatureMOVDIR64B
+)
+
+// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
+// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
+const (
+ X86FeatureXSAVEOPT Feature = 4*32 + iota
+ X86FeatureXSAVEC
+ X86FeatureXGETBV1
+ X86FeatureXSAVES
+ // EAX[31:4] are reserved.
+)
+
+// Block 5 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):ECX.
+const (
+ X86FeatureLAHF64 Feature = 5*32 + iota
+ X86FeatureCMP_LEGACY
+ X86FeatureSVM
+ X86FeatureEXTAPIC
+ X86FeatureCR8_LEGACY
+ X86FeatureLZCNT
+ X86FeatureSSE4A
+ X86FeatureMISALIGNSSE
+ X86FeaturePREFETCHW
+ X86FeatureOSVW
+ X86FeatureIBS
+ X86FeatureXOP
+ X86FeatureSKINIT
+ X86FeatureWDT
+ _ // ecx bit 14 is reserved.
+ X86FeatureLWP
+ X86FeatureFMA4
+ X86FeatureTCE
+ _ // ecx bit 18 is reserved.
+ _ // ecx bit 19 is reserved.
+ _ // ecx bit 20 is reserved.
+ X86FeatureTBM
+ X86FeatureTOPOLOGY
+ X86FeaturePERFCTR_CORE
+ X86FeaturePERFCTR_NB
+ _ // ecx bit 25 is reserved.
+ X86FeatureBPEXT
+ X86FeaturePERFCTR_TSC
+ X86FeaturePERFCTR_LLC
+ X86FeatureMWAITX
+ // ECX[31:30] are reserved.
+)
+
+// Block 6 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):EDX.
+//
+// These are sparse, and so the bit positions are assigned manually.
+const (
+ // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
+ // also defined in block 1 (in identical bit positions). Those features
+ // are not listed here.
+ block6DuplicateMask = 0x183f3ff
+
+ X86FeatureSYSCALL Feature = 6*32 + 11
+ X86FeatureNX Feature = 6*32 + 20
+ X86FeatureMMXEXT Feature = 6*32 + 22
+ X86FeatureFXSR_OPT Feature = 6*32 + 25
+ X86FeatureGBPAGES Feature = 6*32 + 26
+ X86FeatureRDTSCP Feature = 6*32 + 27
+ X86FeatureLM Feature = 6*32 + 29
+ X86Feature3DNOWEXT Feature = 6*32 + 30
+ X86Feature3DNOW Feature = 6*32 + 31
+)
+
+// linuxBlockOrder defines the order in which linux organizes the feature
+// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
+// which doesn't match well here, so for the /proc/cpuinfo generation we simply
+// re-map the blocks to Linux's ordering and then go through the bits in each
+// block.
+var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
+var x86FeatureStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureSSE3: "pni",
+ X86FeaturePCLMULDQ: "pclmulqdq",
+ X86FeatureDTES64: "dtes64",
+ X86FeatureMONITOR: "monitor",
+ X86FeatureDSCPL: "ds_cpl",
+ X86FeatureVMX: "vmx",
+ X86FeatureSMX: "smx",
+ X86FeatureEST: "est",
+ X86FeatureTM2: "tm2",
+ X86FeatureSSSE3: "ssse3",
+ X86FeatureCNXTID: "cid",
+ X86FeatureSDBG: "sdbg",
+ X86FeatureFMA: "fma",
+ X86FeatureCX16: "cx16",
+ X86FeatureXTPR: "xtpr",
+ X86FeaturePDCM: "pdcm",
+ X86FeaturePCID: "pcid",
+ X86FeatureDCA: "dca",
+ X86FeatureSSE4_1: "sse4_1",
+ X86FeatureSSE4_2: "sse4_2",
+ X86FeatureX2APIC: "x2apic",
+ X86FeatureMOVBE: "movbe",
+ X86FeaturePOPCNT: "popcnt",
+ X86FeatureTSCD: "tsc_deadline_timer",
+ X86FeatureAES: "aes",
+ X86FeatureXSAVE: "xsave",
+ X86FeatureAVX: "avx",
+ X86FeatureF16C: "f16c",
+ X86FeatureRDRAND: "rdrand",
+
+ // Block 1.
+ X86FeatureFPU: "fpu",
+ X86FeatureVME: "vme",
+ X86FeatureDE: "de",
+ X86FeaturePSE: "pse",
+ X86FeatureTSC: "tsc",
+ X86FeatureMSR: "msr",
+ X86FeaturePAE: "pae",
+ X86FeatureMCE: "mce",
+ X86FeatureCX8: "cx8",
+ X86FeatureAPIC: "apic",
+ X86FeatureSEP: "sep",
+ X86FeatureMTRR: "mtrr",
+ X86FeaturePGE: "pge",
+ X86FeatureMCA: "mca",
+ X86FeatureCMOV: "cmov",
+ X86FeaturePAT: "pat",
+ X86FeaturePSE36: "pse36",
+ X86FeaturePSN: "pn",
+ X86FeatureCLFSH: "clflush",
+ X86FeatureDS: "dts",
+ X86FeatureACPI: "acpi",
+ X86FeatureMMX: "mmx",
+ X86FeatureFXSR: "fxsr",
+ X86FeatureSSE: "sse",
+ X86FeatureSSE2: "sse2",
+ X86FeatureSS: "ss",
+ X86FeatureHTT: "ht",
+ X86FeatureTM: "tm",
+ X86FeatureIA64: "ia64",
+ X86FeaturePBE: "pbe",
+
+ // Block 2.
+ X86FeatureFSGSBase: "fsgsbase",
+ X86FeatureTSC_ADJUST: "tsc_adjust",
+ X86FeatureBMI1: "bmi1",
+ X86FeatureHLE: "hle",
+ X86FeatureAVX2: "avx2",
+ X86FeatureSMEP: "smep",
+ X86FeatureBMI2: "bmi2",
+ X86FeatureERMS: "erms",
+ X86FeatureINVPCID: "invpcid",
+ X86FeatureRTM: "rtm",
+ X86FeatureCQM: "cqm",
+ X86FeatureMPX: "mpx",
+ X86FeatureRDT: "rdt_a",
+ X86FeatureAVX512F: "avx512f",
+ X86FeatureAVX512DQ: "avx512dq",
+ X86FeatureRDSEED: "rdseed",
+ X86FeatureADX: "adx",
+ X86FeatureSMAP: "smap",
+ X86FeatureCLWB: "clwb",
+ X86FeatureAVX512PF: "avx512pf",
+ X86FeatureAVX512ER: "avx512er",
+ X86FeatureAVX512CD: "avx512cd",
+ X86FeatureSHA: "sha_ni",
+ X86FeatureAVX512BW: "avx512bw",
+ X86FeatureAVX512VL: "avx512vl",
+
+ // Block 3.
+ X86FeatureAVX512VBMI: "avx512vbmi",
+ X86FeatureUMIP: "umip",
+ X86FeaturePKU: "pku",
+ X86FeatureOSPKE: "ospke",
+ X86FeatureWAITPKG: "waitpkg",
+ X86FeatureAVX512_VBMI2: "avx512_vbmi2",
+ X86FeatureGFNI: "gfni",
+ X86FeatureVAES: "vaes",
+ X86FeatureVPCLMULQDQ: "vpclmulqdq",
+ X86FeatureAVX512_VNNI: "avx512_vnni",
+ X86FeatureAVX512_BITALG: "avx512_bitalg",
+ X86FeatureTME: "tme",
+ X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq",
+ X86FeatureLA57: "la57",
+ X86FeatureRDPID: "rdpid",
+ X86FeatureCLDEMOTE: "cldemote",
+ X86FeatureMOVDIRI: "movdiri",
+ X86FeatureMOVDIR64B: "movdir64b",
+
+ // Block 4.
+ X86FeatureXSAVEOPT: "xsaveopt",
+ X86FeatureXSAVEC: "xsavec",
+ X86FeatureXGETBV1: "xgetbv1",
+ X86FeatureXSAVES: "xsaves",
+
+ // Block 5.
+ X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
+ X86FeatureCMP_LEGACY: "cmp_legacy",
+ X86FeatureSVM: "svm",
+ X86FeatureEXTAPIC: "extapic",
+ X86FeatureCR8_LEGACY: "cr8_legacy",
+ X86FeatureLZCNT: "abm", // Advanced bit manipulation
+ X86FeatureSSE4A: "sse4a",
+ X86FeatureMISALIGNSSE: "misalignsse",
+ X86FeaturePREFETCHW: "3dnowprefetch",
+ X86FeatureOSVW: "osvw",
+ X86FeatureIBS: "ibs",
+ X86FeatureXOP: "xop",
+ X86FeatureSKINIT: "skinit",
+ X86FeatureWDT: "wdt",
+ X86FeatureLWP: "lwp",
+ X86FeatureFMA4: "fma4",
+ X86FeatureTCE: "tce",
+ X86FeatureTBM: "tbm",
+ X86FeatureTOPOLOGY: "topoext",
+ X86FeaturePERFCTR_CORE: "perfctr_core",
+ X86FeaturePERFCTR_NB: "perfctr_nb",
+ X86FeatureBPEXT: "bpext",
+ X86FeaturePERFCTR_TSC: "ptsc",
+ X86FeaturePERFCTR_LLC: "perfctr_llc",
+ X86FeatureMWAITX: "mwaitx",
+
+ // Block 6.
+ X86FeatureSYSCALL: "syscall",
+ X86FeatureNX: "nx",
+ X86FeatureMMXEXT: "mmxext",
+ X86FeatureFXSR_OPT: "fxsr_opt",
+ X86FeatureGBPAGES: "pdpe1gb",
+ X86FeatureRDTSCP: "rdtscp",
+ X86FeatureLM: "lm",
+ X86Feature3DNOWEXT: "3dnowext",
+ X86Feature3DNOW: "3dnow",
+}
+
+// These flags are parse only---they can be used for setting / unsetting the
+// flags, but will not get printed out in /proc/cpuinfo.
+var x86FeatureParseOnlyStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureOSXSAVE: "osxsave",
+
+ // Block 2.
+ X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
+ X86FeatureFPCSDS: "fpcsds",
+ X86FeatureIPT: "pt",
+ X86FeatureCLFLUSHOPT: "clfushopt",
+
+ // Block 3.
+ X86FeaturePREFETCHWT1: "prefetchwt1",
+}
+
+// intelCacheDescriptors describe the caches and TLBs on the system. They are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
+// Just a way to wrap cpuid function numbers.
+type cpuidFunction uint32
+
+// The constants below are the lower or "standard" cpuid functions, ordered as
+// defined by the hardware.
+const (
+ vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
+ featureInfo // Returns basic feature bits and processor signature.
+ intelCacheDescriptors // Returns list of cache descriptors. Intel only.
+ intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
+ intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
+ monitorMwaitParams // Returns information about monitor/mwait instructions.
+ powerParams // Returns information about power management and thermal sensors.
+ extendedFeatureInfo // Returns extended feature bits.
+ _ // Function 0x8 is reserved.
+ intelDCAParams // Returns direct cache access information. Intel only.
+ intelPMCInfo // Returns information about performance monitoring features. Intel only.
+ intelX2APICInfo // Returns core/logical processor topology. Intel only.
+ _ // Function 0xc is reserved.
+ xSaveInfo // Returns information about extended state management.
+)
+
+// The "extended" functions start at 0x80000000.
+const (
+ extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
+ extendedFeatures // Returns some extended feature bits in edx and ecx.
+)
+
+// These are the extended floating point state features. They are used to
+// enumerate floating point features in XCR0, XSTATE_BV, etc.
+const (
+ XSAVEFeatureX87 = 1 << 0
+ XSAVEFeatureSSE = 1 << 1
+ XSAVEFeatureAVX = 1 << 2
+ XSAVEFeatureBNDREGS = 1 << 3
+ XSAVEFeatureBNDCSR = 1 << 4
+ XSAVEFeatureAVX512op = 1 << 5
+ XSAVEFeatureAVX512zmm0 = 1 << 6
+ XSAVEFeatureAVX512zmm16 = 1 << 7
+ XSAVEFeaturePKRU = 1 << 9
+)
+
+var cpuFreqMHz float64
+
+// x86FeaturesFromString includes features from x86FeatureStrings and
+// x86FeatureParseOnlyStrings.
+var x86FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string plus a bool to indicate if it could find the feature.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := x86FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(false); s != "" {
+ return s
+ }
+
+ block := int(f) / 32
+ bit := int(f) % 32
+ return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
+}
+
+func (f Feature) flagString(cpuinfoOnly bool) string {
+ if s, ok := x86FeatureStrings[f]; ok {
+ return s
+ }
+ if !cpuinfoOnly {
+ return x86FeatureParseOnlyStrings[f]
+ }
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
+ VendorID string
+
+ // ExtendedFamily is part of the processor signature.
+ ExtendedFamily uint8
+
+ // ExtendedModel is part of the processor signature.
+ ExtendedModel uint8
+
+ // ProcessorType is part of the processor signature.
+ ProcessorType uint8
+
+ // Family is part of the processor signature.
+ Family uint8
+
+ // Model is part of the processor signature.
+ Model uint8
+
+ // SteppingID is part of the processor signature.
+ SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
+}
+
+// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
+// equivalent to the "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
+ var s []string
+ for _, b := range linuxBlockOrder {
+ for i := 0; i < blockSize; i++ {
+ if f := featureID(b, i); fs.Set[f] {
+ if fstr := f.flagString(cpuinfoOnly); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
+// not always printed in Linux. The bogomips field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(b, "fpu\t\t: yes")
+ fmt.Fprintln(b, "fpu_exception\t: yes")
+ fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(b, "wp\t\t: yes")
+ fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
+// AMD returns true if fs describes an AMD CPU.
+func (fs *FeatureSet) AMD() bool {
+ return fs.VendorID == amdVendorID
+}
+
+// Intel returns true if fs describes an Intel CPU.
+func (fs *FeatureSet) Intel() bool {
+ return fs.VendorID == intelVendorID
+}
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ hfs := HostFeatureSet()
+
+ if diff := fs.Subtract(hfs); diff != nil {
+ return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
+ }
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
+ return nil
+}
+
+// Helper to convert 3 regs into 12-byte vendor ID.
+func vendorIDFromRegs(bx, cx, dx uint32) string {
+ bytes := make([]byte, 0, 12)
+ for i := uint(0); i < 4; i++ {
+ b := byte(bx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(dx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(cx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+ return string(bytes)
+}
+
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point registers, and other cpu state that's not
+// associated with the normal task context.
+//
+// Note: We can save some space here with an optimization where we use a
+// smaller chunk of memory depending on features that are actually enabled.
+// Currently we just use the largest possible size for simplicity (which is
+// about 2.5K worst case, with avx512).
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ if fs.UseXsave() {
+ return uint(maxXsaveSize), 64
+ }
+
+ // If we don't support xsave, we fall back to fxsave, which requires
+ // 512 bytes aligned to 16 bytes.
+ return 512, 16
+}
+
+// ValidXCR0Mask returns the bits that may be set to 1 in control register
+// XCR0.
+func (fs *FeatureSet) ValidXCR0Mask() uint64 {
+ if !fs.UseXsave() {
+ return 0
+ }
+ eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
+ return uint64(edx)<<32 | uint64(eax)
+}
+
+// vendorIDRegs returns the 3 register values used to construct the 12-byte
+// vendor ID string for eax=0.
+func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
+ for i := uint(0); i < 4; i++ {
+ bx |= uint32(fs.VendorID[i]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ dx |= uint32(fs.VendorID[i+4]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ cx |= uint32(fs.VendorID[i+8]) << (i * 8)
+ }
+ return
+}
+
+// signature returns the signature dword that's returned in eax when eax=1.
+func (fs *FeatureSet) signature() uint32 {
+ var s uint32
+ s |= uint32(fs.SteppingID & 0xf)
+ s |= uint32(fs.Model&0xf) << 4
+ s |= uint32(fs.Family&0xf) << 8
+ s |= uint32(fs.ProcessorType&0x3) << 12
+ s |= uint32(fs.ExtendedModel&0xf) << 16
+ s |= uint32(fs.ExtendedFamily&0xff) << 20
+ return s
+}
+
+// Helper to deconstruct signature dword.
+func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
+ sid = uint8(v & 0xf)
+ m = uint8(v>>4) & 0xf
+ f = uint8(v>>8) & 0xf
+ pt = uint8(v>>12) & 0x3
+ em = uint8(v>>16) & 0xf
+ ef = uint8(v >> 20)
+ return
+}
+
+// Helper to convert blockwise feature bit masks into a set of features. Masks
+// must be provided in order for each block, without skipping them. If a block
+// does not matter for this feature set, 0 is specified.
+func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
+ s := make(map[Feature]bool)
+ for b, blockMask := range blocks {
+ for i := 0; i < blockSize; i++ {
+ if blockMask&1 != 0 {
+ s[featureID(block(b), i)] = true
+ }
+ blockMask >>= 1
+ }
+ }
+ return s
+}
+
+// blockMask returns the 32-bit mask associated with a block of features.
+func (fs *FeatureSet) blockMask(b block) uint32 {
+ var mask uint32
+ for i := 0; i < blockSize; i++ {
+ if fs.Set[featureID(b, i)] {
+ mask |= 1 << uint(i)
+ }
+ }
+ return mask
+}
+
+// Remove removes a Feature from a FeatureSet. It ignores features
+// that are not in the FeatureSet.
+func (fs *FeatureSet) Remove(feature Feature) {
+ delete(fs.Set, feature)
+}
+
+// Add adds a Feature to a FeatureSet. It ignores duplicate features.
+func (fs *FeatureSet) Add(feature Feature) {
+ fs.Set[feature] = true
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// Subtract returns the features present in fs that are not present in other.
+// If all features in fs are present in other, Subtract returns nil.
+func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
+ for f := range fs.Set {
+ if !other.Set[f] {
+ if diff == nil {
+ diff = make(map[Feature]bool)
+ }
+ diff[f] = true
+ }
+ }
+
+ return
+}
+
+// EmulateID emulates a cpuid instruction based on the feature set.
+func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
+ switch cpuidFunction(origAx) {
+ case vendorID:
+ ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
+ bx, dx, cx = fs.vendorIDRegs()
+ case featureInfo:
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
+ cx = fs.blockMask(block(0))
+ dx = fs.blockMask(block(1))
+ ax = fs.signature()
+ case intelCacheDescriptors:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // "The least-significant byte in register EAX (register AL)
+ // will always return 01H. Software should ignore this value
+ // and not interpret it as an informational descriptor." - SDM
+ //
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
+ case xSaveInfo:
+ if !fs.UseXsave() {
+ return 0, 0, 0, 0
+ }
+ return HostID(uint32(xSaveInfo), origCx)
+ case extendedFeatureInfo:
+ if origCx != 0 {
+ break // Only leaf 0 is supported.
+ }
+ bx = fs.blockMask(block(2))
+ cx = fs.blockMask(block(3))
+ case extendedFunctionInfo:
+ // We only support showing the extended features.
+ ax = uint32(extendedFeatures)
+ cx = 0
+ case extendedFeatures:
+ cx = fs.blockMask(block(5))
+ dx = fs.blockMask(block(6))
+ if fs.AMD() {
+ // AMD duplicates some block 1 features in block 6.
+ dx |= fs.blockMask(block(1)) & block6DuplicateMask
+ }
+ }
+
+ return
+}
+
+// UseXsave returns the choice of fp state saving instruction.
+func (fs *FeatureSet) UseXsave() bool {
+ return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
+}
+
+// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
+func (fs *FeatureSet) UseXsaveopt() bool {
+ return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
+}
+
+// HostID executes a native CPUID instruction.
+func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
+
+// HostFeatureSet uses cpuid to get host values and construct a feature set
+// that matches that of the host machine. Note that there are several places
+// where there appear to be some unnecessary assignments between register names
+// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
+// where the different feature blocks come from, to make the code easier to
+// inspect and read.
+func HostFeatureSet() *FeatureSet {
+ // eax=0 gets max supported feature and vendor ID.
+ _, bx, cx, dx := HostID(0, 0)
+ vendorID := vendorIDFromRegs(bx, cx, dx)
+
+ // eax=1 gets basic features in ecx:edx.
+ ax, bx, cx, dx := HostID(1, 0)
+ featureBlock0 := cx
+ featureBlock1 := dx
+ ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
+
+ // eax=7, ecx=0 gets extended features in ecx:ebx.
+ _, bx, cx, _ = HostID(7, 0)
+ featureBlock2 := bx
+ featureBlock3 := cx
+
+ // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
+ var featureBlock4 uint32
+ if (featureBlock0 & (1 << 26)) != 0 {
+ featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
+ }
+
+ // eax=0x80000000 gets supported extended levels. We use this to
+ // determine if there are any non-zero block 4 or block 6 bits to find.
+ var featureBlock5, featureBlock6 uint32
+ if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
+ // eax=0x80000001 gets AMD added feature bits.
+ _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
+ featureBlock5 = cx
+ // Ignore features duplicated from block 1 on AMD. These bits
+ // are reserved on Intel.
+ featureBlock6 = dx &^ block6DuplicateMask
+ }
+
+ set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
+ return &FeatureSet{
+ Set: set,
+ VendorID: vendorID,
+ ExtendedFamily: ef,
+ ExtendedModel: em,
+ ProcessorType: pt,
+ Family: f,
+ Model: m,
+ SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
+ }
+}
+
+// Reads max cpu frequency from host /proc/cpuinfo. Must run before
+// whitelisting. This value is used to create the fake /proc/cpuinfo from a
+// FeatureSet.
+func initCPUFreq() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0... The standalone VDSO bails out in the same
+ // way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo. On machines with
+ // frequency scaling enabled, this will only get the current value
+ // which will likely be inaccurate. This is fine on machines with
+ // frequency scaling disabled.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ if strings.Contains(line, "cpu MHz") {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
+ return
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ return
+ }
+ return
+ }
+ }
+ log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
+}
+
+func initFeaturesFromString() {
+ for f, s := range x86FeatureStrings {
+ x86FeaturesFromString[s] = f
+ }
+ for f, s := range x86FeatureParseOnlyStrings {
+ x86FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ // initCpuFreq must be run before whitelists are enabled.
+ initCPUFreq()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_x86_test.go
index a707ebb55..0fe20c213 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_x86_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build i386 amd64
+
package cpuid
import (
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index ee84471b2..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -8,9 +8,11 @@ go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index 9fb3fee24..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -67,7 +67,8 @@ func Parse(pathname string) Path {
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..798a65eca
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index cab5fae55..b4f7bb5a4 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar
microsecond := int(timestamp.Nanosecond() / 1000)
// 0 = this frame.
- // 1 = Debugf, etc.
- // 2 = Caller.
- _, file, line, ok := runtime.Caller(2)
+ _, file, line, ok := runtime.Caller(depth + 1)
if ok {
// Trim any directory path from the file.
slash := strings.LastIndexByte(file, byte('/'))
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..0943db1cc 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -62,7 +62,7 @@ type JSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index cee6eb514..6c6fc8b6f 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -33,7 +33,7 @@ type K8sJSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 5056f17e6..a794da1aa 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -79,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
for _, e := range *m {
- e.Emit(level, timestamp, format, v...)
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -167,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -198,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -257,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index 93d4f2b8c..006fcd9ab 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -46,7 +46,6 @@ var (
//
// TODO(b/67298402): Support non-cumulative metrics.
// TODO(b/67298427): Support metric fields.
-//
type Uint64Metric struct {
// value is the actual value of the metric. It must be accessed
// atomically.
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index 4ccc1de86..8904afad9 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -16,7 +16,6 @@ go_library(
"messages.go",
"p9.go",
"path_tree.go",
- "pool.go",
"server.go",
"transport.go",
"transport_flipcall.go",
@@ -27,6 +26,7 @@ go_library(
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/pool",
"//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
@@ -41,7 +41,6 @@ go_test(
"client_test.go",
"messages_test.go",
"p9_test.go",
- "pool_test.go",
"transport_test.go",
"version_test.go",
],
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
index 249536d8a..6a4951821 100644
--- a/pkg/p9/buffer.go
+++ b/pkg/p9/buffer.go
@@ -20,16 +20,16 @@ import (
// encoder is used for messages and 9P primitives.
type encoder interface {
- // Decode decodes from the given buffer. Decode may be called more than once
+ // decode decodes from the given buffer. decode may be called more than once
// to reuse the instance. It must clear any previous state.
//
// This may not fail, exhaustion will be recorded in the buffer.
- Decode(b *buffer)
+ decode(b *buffer)
- // Encode encodes to the given buffer.
+ // encode encodes to the given buffer.
//
// This may not fail.
- Encode(b *buffer)
+ encode(b *buffer)
}
// order is the byte order used for encoding.
@@ -39,7 +39,7 @@ var order = binary.LittleEndian
//
// This is passed to the encoder methods.
type buffer struct {
- // data is the underlying data. This may grow during Encode.
+ // data is the underlying data. This may grow during encode.
data []byte
// overflow indicates whether an overflow has occurred.
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 4045e41fa..a6f493b82 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -22,6 +22,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -74,10 +75,10 @@ type Client struct {
socket *unet.Socket
// tagPool is the collection of available tags.
- tagPool pool
+ tagPool pool.Pool
// fidPool is the collection of available fids.
- fidPool pool
+ fidPool pool.Pool
// messageSize is the maximum total size of a message.
messageSize uint32
@@ -155,8 +156,8 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
}
c := &Client{
socket: socket,
- tagPool: pool{start: 1, limit: uint64(NoTag)},
- fidPool: pool{start: 1, limit: uint64(NoFID)},
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
pending: make(map[Tag]*response),
recvr: make(chan bool, 1),
messageSize: messageSize,
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index 0254e4ccc..2ee07b664 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -194,6 +194,39 @@ func (c *clientFile) SetXattr(name, value string, flags uint32) error {
return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
}
+// ListXattr implements File.ListXattr.
+func (c *clientFile) ListXattr(size uint64) (map[string]struct{}, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ rlistxattr := Rlistxattr{}
+ if err := c.client.sendRecv(&Tlistxattr{FID: c.fid, Size: size}, &rlistxattr); err != nil {
+ return nil, err
+ }
+
+ xattrs := make(map[string]struct{}, len(rlistxattr.Xattrs))
+ for _, x := range rlistxattr.Xattrs {
+ xattrs[x] = struct{}{}
+ }
+ return xattrs, nil
+}
+
+// RemoveXattr implements File.RemoveXattr.
+func (c *clientFile) RemoveXattr(name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tremovexattr{FID: c.fid, Name: name}, &Rremovexattr{})
+}
+
// Allocate implements File.Allocate.
func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
if atomic.LoadUint32(&c.closed) != 0 {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 4607cfcdf..d4ffbc8e3 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -105,6 +105,22 @@ type File interface {
// TODO(b/127675828): Determine concurrency guarantees once implemented.
SetXattr(name, value string, flags uint32) error
+ // ListXattr lists the names of the extended attributes on this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ ListXattr(size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes extended attributes on this node.
+ //
+ // TODO(b/148303075): Determine concurrency guarantees once implemented.
+ RemoveXattr(name string) error
+
// Allocate allows the caller to directly manipulate the allocated disk space
// for the file. See fallocate(2) for more details.
Allocate(mode AllocateMode, offset, length uint64) error
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 7d6653a07..2ac45eb80 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -942,6 +942,39 @@ func (t *Tsetxattr) handle(cs *connState) message {
}
// handle implements handler.handle.
+func (t *Tlistxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ xattrs, err := ref.file.ListXattr(t.Size)
+ if err != nil {
+ return newErr(err)
+ }
+ xattrList := make([]string, 0, len(xattrs))
+ for x := range xattrs {
+ xattrList = append(xattrList, x)
+ }
+ return &Rlistxattr{Xattrs: xattrList}
+}
+
+// handle implements handler.handle.
+func (t *Tremovexattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.file.RemoveXattr(t.Name); err != nil {
+ return newErr(err)
+ }
+ return &Rremovexattr{}
+}
+
+// handle implements handler.handle.
func (t *Treaddir) handle(cs *connState) message {
ref, ok := cs.LookupFID(t.Directory)
if !ok {
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index ceb723d86..3863ad1f5 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -51,7 +51,7 @@ type payloader interface {
// SetPayload returns the decoded message.
//
// This is going to be total message size - FixedSize. But this should
- // be validated during Decode, which will be called after SetPayload.
+ // be validated during decode, which will be called after SetPayload.
SetPayload([]byte)
}
@@ -90,14 +90,14 @@ type Tversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (t *Tversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tversion) decode(b *buffer) {
t.MSize = b.Read32()
t.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tversion) encode(b *buffer) {
b.Write32(t.MSize)
b.WriteString(t.Version)
}
@@ -121,14 +121,14 @@ type Rversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (r *Rversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rversion) decode(b *buffer) {
r.MSize = b.Read32()
r.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rversion) encode(b *buffer) {
b.Write32(r.MSize)
b.WriteString(r.Version)
}
@@ -149,13 +149,13 @@ type Tflush struct {
OldTag Tag
}
-// Decode implements encoder.Decode.
-func (t *Tflush) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflush) decode(b *buffer) {
t.OldTag = b.ReadTag()
}
-// Encode implements encoder.Encode.
-func (t *Tflush) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflush) encode(b *buffer) {
b.WriteTag(t.OldTag)
}
@@ -173,12 +173,12 @@ func (t *Tflush) String() string {
type Rflush struct {
}
-// Decode implements encoder.Decode.
-func (*Rflush) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rflush) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflush) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rflush) encode(*buffer) {
}
// Type implements message.Type.
@@ -188,7 +188,7 @@ func (*Rflush) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rflush) String() string {
- return fmt.Sprintf("RFlush{}")
+ return "RFlush{}"
}
// Twalk is a walk request.
@@ -203,8 +203,8 @@ type Twalk struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -214,8 +214,8 @@ func (t *Twalk) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -240,22 +240,22 @@ type Rwalk struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwalk) decode(b *buffer) {
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwalk) encode(b *buffer) {
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -275,13 +275,13 @@ type Tclunk struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tclunk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tclunk) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tclunk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tclunk) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -299,12 +299,12 @@ func (t *Tclunk) String() string {
type Rclunk struct {
}
-// Decode implements encoder.Decode.
-func (*Rclunk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rclunk) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rclunk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rclunk) encode(*buffer) {
}
// Type implements message.Type.
@@ -314,7 +314,7 @@ func (*Rclunk) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rclunk) String() string {
- return fmt.Sprintf("Rclunk{}")
+ return "Rclunk{}"
}
// Tremove is a remove request.
@@ -325,13 +325,13 @@ type Tremove struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tremove) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tremove) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tremove) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tremove) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -349,12 +349,12 @@ func (t *Tremove) String() string {
type Rremove struct {
}
-// Decode implements encoder.Decode.
-func (*Rremove) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rremove) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rremove) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rremove) encode(*buffer) {
}
// Type implements message.Type.
@@ -364,7 +364,7 @@ func (*Rremove) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rremove) String() string {
- return fmt.Sprintf("Rremove{}")
+ return "Rremove{}"
}
// Rlerror is an error response.
@@ -374,13 +374,13 @@ type Rlerror struct {
Error uint32
}
-// Decode implements encoder.Decode.
-func (r *Rlerror) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rlerror) decode(b *buffer) {
r.Error = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlerror) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rlerror) encode(b *buffer) {
b.Write32(r.Error)
}
@@ -409,16 +409,16 @@ type Tauth struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tauth) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tauth) decode(b *buffer) {
t.AuthenticationFID = b.ReadFID()
t.UserName = b.ReadString()
t.AttachName = b.ReadString()
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tauth) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tauth) encode(b *buffer) {
b.WriteFID(t.AuthenticationFID)
b.WriteString(t.UserName)
b.WriteString(t.AttachName)
@@ -437,7 +437,7 @@ func (t *Tauth) String() string {
// Rauth is an authentication response.
//
-// Encode, Decode and Length are inherited directly from QID.
+// encode and decode are inherited directly from QID.
type Rauth struct {
QID
}
@@ -463,16 +463,16 @@ type Tattach struct {
Auth Tauth
}
-// Decode implements encoder.Decode.
-func (t *Tattach) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tattach) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Auth.Decode(b)
+ t.Auth.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tattach) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tattach) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Auth.Encode(b)
+ t.Auth.encode(b)
}
// Type implements message.Type.
@@ -509,14 +509,14 @@ type Tlopen struct {
Flags OpenFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlopen) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlopen) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadOpenFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlopen) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlopen) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteOpenFlags(t.Flags)
}
@@ -542,15 +542,15 @@ type Rlopen struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlopen) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rlopen) decode(b *buffer) {
+ r.QID.decode(b)
r.IoUnit = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlopen) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rlopen) encode(b *buffer) {
+ r.QID.encode(b)
b.Write32(r.IoUnit)
}
@@ -587,8 +587,8 @@ type Tlcreate struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tlcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.OpenFlags = b.ReadOpenFlags()
@@ -596,8 +596,8 @@ func (t *Tlcreate) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tlcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.WriteOpenFlags(t.OpenFlags)
@@ -617,7 +617,7 @@ func (t *Tlcreate) String() string {
// Rlcreate is a create response.
//
-// The Encode, Decode, etc. methods are inherited from Rlopen.
+// The encode, decode, etc. methods are inherited from Rlopen.
type Rlcreate struct {
Rlopen
}
@@ -647,16 +647,16 @@ type Tsymlink struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tsymlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsymlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Target = b.ReadString()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tsymlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsymlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteString(t.Target)
@@ -679,14 +679,14 @@ type Rsymlink struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rsymlink) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rsymlink) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rsymlink) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rsymlink) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -711,15 +711,15 @@ type Tlink struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Tlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Target = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteFID(t.Target)
b.WriteString(t.Name)
@@ -744,17 +744,17 @@ func (*Rlink) Type() MsgType {
return MsgRlink
}
-// Decode implements encoder.Decode.
-func (*Rlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rlink) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rlink) encode(*buffer) {
}
// String implements fmt.Stringer.
func (r *Rlink) String() string {
- return fmt.Sprintf("Rlink{}")
+ return "Rlink{}"
}
// Trenameat is a rename request.
@@ -772,16 +772,16 @@ type Trenameat struct {
NewName string
}
-// Decode implements encoder.Decode.
-func (t *Trenameat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trenameat) decode(b *buffer) {
t.OldDirectory = b.ReadFID()
t.OldName = b.ReadString()
t.NewDirectory = b.ReadFID()
t.NewName = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trenameat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trenameat) encode(b *buffer) {
b.WriteFID(t.OldDirectory)
b.WriteString(t.OldName)
b.WriteFID(t.NewDirectory)
@@ -802,12 +802,12 @@ func (t *Trenameat) String() string {
type Rrenameat struct {
}
-// Decode implements encoder.Decode.
-func (*Rrenameat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rrenameat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrenameat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rrenameat) encode(*buffer) {
}
// Type implements message.Type.
@@ -817,7 +817,7 @@ func (*Rrenameat) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rrenameat) String() string {
- return fmt.Sprintf("Rrenameat{}")
+ return "Rrenameat{}"
}
// Tunlinkat is an unlink request.
@@ -832,15 +832,15 @@ type Tunlinkat struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Tunlinkat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tunlinkat) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tunlinkat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tunlinkat) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.Write32(t.Flags)
@@ -860,12 +860,12 @@ func (t *Tunlinkat) String() string {
type Runlinkat struct {
}
-// Decode implements encoder.Decode.
-func (*Runlinkat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Runlinkat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Runlinkat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Runlinkat) encode(*buffer) {
}
// Type implements message.Type.
@@ -875,7 +875,7 @@ func (*Runlinkat) Type() MsgType {
// String implements fmt.Stringer.
func (r *Runlinkat) String() string {
- return fmt.Sprintf("Runlinkat{}")
+ return "Runlinkat{}"
}
// Trename is a rename request.
@@ -893,15 +893,15 @@ type Trename struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Trename) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trename) decode(b *buffer) {
t.FID = b.ReadFID()
t.Directory = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trename) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trename) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.Directory)
b.WriteString(t.Name)
@@ -921,12 +921,12 @@ func (t *Trename) String() string {
type Rrename struct {
}
-// Decode implements encoder.Decode.
-func (*Rrename) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rrename) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrename) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rrename) encode(*buffer) {
}
// Type implements message.Type.
@@ -936,7 +936,7 @@ func (*Rrename) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rrename) String() string {
- return fmt.Sprintf("Rrename{}")
+ return "Rrename{}"
}
// Treadlink is a readlink request.
@@ -945,13 +945,13 @@ type Treadlink struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Treadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treadlink) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Treadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treadlink) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -971,13 +971,13 @@ type Rreadlink struct {
Target string
}
-// Decode implements encoder.Decode.
-func (r *Rreadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreadlink) decode(b *buffer) {
r.Target = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rreadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreadlink) encode(b *buffer) {
b.WriteString(r.Target)
}
@@ -1003,15 +1003,15 @@ type Tread struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Tread) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tread) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tread) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tread) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1033,20 +1033,20 @@ type Rread struct {
Data []byte
}
-// Decode implements encoder.Decode.
+// decode implements encoder.decode.
//
// Data is automatically decoded via Payload.
-func (r *Rread) Decode(b *buffer) {
+func (r *Rread) decode(b *buffer) {
count := b.Read32()
if count != uint32(len(r.Data)) {
b.markOverrun()
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// Data is automatically encoded via Payload.
-func (r *Rread) Encode(b *buffer) {
+func (r *Rread) encode(b *buffer) {
b.Write32(uint32(len(r.Data)))
}
@@ -1087,8 +1087,8 @@ type Twrite struct {
Data []byte
}
-// Decode implements encoder.Decode.
-func (t *Twrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twrite) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
count := b.Read32()
@@ -1097,10 +1097,10 @@ func (t *Twrite) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// This uses the buffer payload to avoid a copy.
-func (t *Twrite) Encode(b *buffer) {
+func (t *Twrite) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(uint32(len(t.Data)))
@@ -1137,13 +1137,13 @@ type Rwrite struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (r *Rwrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwrite) decode(b *buffer) {
r.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rwrite) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwrite) encode(b *buffer) {
b.Write32(r.Count)
}
@@ -1178,8 +1178,8 @@ type Tmknod struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmknod) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmknod) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Mode = b.ReadFileMode()
@@ -1188,8 +1188,8 @@ func (t *Tmknod) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmknod) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmknod) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteFileMode(t.Mode)
@@ -1214,14 +1214,14 @@ type Rmknod struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmknod) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmknod) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmknod) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmknod) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1249,16 +1249,16 @@ type Tmkdir struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmkdir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmkdir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Permissions = b.ReadPermissions()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmkdir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmkdir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WritePermissions(t.Permissions)
@@ -1281,14 +1281,14 @@ type Rmkdir struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmkdir) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmkdir) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmkdir) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmkdir) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1310,16 +1310,16 @@ type Tgetattr struct {
AttrMask AttrMask
}
-// Decode implements encoder.Decode.
-func (t *Tgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.AttrMask.Decode(b)
+ t.AttrMask.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.AttrMask.Encode(b)
+ t.AttrMask.encode(b)
}
// Type implements message.Type.
@@ -1344,18 +1344,18 @@ type Rgetattr struct {
Attr Attr
}
-// Decode implements encoder.Decode.
-func (r *Rgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.QID.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.QID.decode(b)
+ r.Attr.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.QID.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.QID.encode(b)
+ r.Attr.encode(b)
}
// Type implements message.Type.
@@ -1380,18 +1380,18 @@ type Tsetattr struct {
SetAttr SetAttr
}
-// Decode implements encoder.Decode.
-func (t *Tsetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Valid.Decode(b)
- t.SetAttr.Decode(b)
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tsetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Valid.Encode(b)
- t.SetAttr.Encode(b)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
}
// Type implements message.Type.
@@ -1408,12 +1408,12 @@ func (t *Tsetattr) String() string {
type Rsetattr struct {
}
-// Decode implements encoder.Decode.
-func (*Rsetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rsetattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rsetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rsetattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1423,7 +1423,7 @@ func (*Rsetattr) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rsetattr) String() string {
- return fmt.Sprintf("Rsetattr{}")
+ return "Rsetattr{}"
}
// Tallocate is an allocate request. This is an extension to 9P protocol, not
@@ -1435,18 +1435,18 @@ type Tallocate struct {
Length uint64
}
-// Decode implements encoder.Decode.
-func (t *Tallocate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tallocate) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Mode.Decode(b)
+ t.Mode.decode(b)
t.Offset = b.Read64()
t.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tallocate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tallocate) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Mode.Encode(b)
+ t.Mode.encode(b)
b.Write64(t.Offset)
b.Write64(t.Length)
}
@@ -1465,12 +1465,12 @@ func (t *Tallocate) String() string {
type Rallocate struct {
}
-// Decode implements encoder.Decode.
-func (*Rallocate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rallocate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rallocate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rallocate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1480,7 +1480,71 @@ func (*Rallocate) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rallocate) String() string {
- return fmt.Sprintf("Rallocate{}")
+ return "Rallocate{}"
+}
+
+// Tlistxattr is a listxattr request.
+type Tlistxattr struct {
+ // FID refers to the file on which to list xattrs.
+ FID FID
+
+ // Size is the buffer size for the xattr list.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tlistxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tlistxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tlistxattr) Type() MsgType {
+ return MsgTlistxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tlistxattr) String() string {
+ return fmt.Sprintf("Tlistxattr{FID: %d, Size: %d}", t.FID, t.Size)
+}
+
+// Rlistxattr is a listxattr response.
+type Rlistxattr struct {
+ // Xattrs is a list of extended attribute names.
+ Xattrs []string
+}
+
+// decode implements encoder.decode.
+func (r *Rlistxattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Xattrs = r.Xattrs[:0]
+ for i := 0; i < int(n); i++ {
+ r.Xattrs = append(r.Xattrs, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rlistxattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Xattrs)))
+ for _, x := range r.Xattrs {
+ b.WriteString(x)
+ }
+}
+
+// Type implements message.Type.
+func (*Rlistxattr) Type() MsgType {
+ return MsgRlistxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rlistxattr) String() string {
+ return fmt.Sprintf("Rlistxattr{Xattrs: %v}", r.Xattrs)
}
// Txattrwalk walks extended attributes.
@@ -1495,15 +1559,15 @@ type Txattrwalk struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Txattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrwalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Txattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrwalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.WriteString(t.Name)
@@ -1525,13 +1589,13 @@ type Rxattrwalk struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (r *Rxattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrwalk) decode(b *buffer) {
r.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rxattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrwalk) encode(b *buffer) {
b.Write64(r.Size)
}
@@ -1563,16 +1627,16 @@ type Txattrcreate struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Txattrcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.AttrSize = b.Read64()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Txattrcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.Write64(t.AttrSize)
@@ -1593,12 +1657,12 @@ func (t *Txattrcreate) String() string {
type Rxattrcreate struct {
}
-// Decode implements encoder.Decode.
-func (r *Rxattrcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrcreate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rxattrcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrcreate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1608,7 +1672,7 @@ func (*Rxattrcreate) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rxattrcreate) String() string {
- return fmt.Sprintf("Rxattrcreate{}")
+ return "Rxattrcreate{}"
}
// Tgetxattr is a getxattr request.
@@ -1623,15 +1687,15 @@ type Tgetxattr struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (t *Tgetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tgetxattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tgetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tgetxattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.Write64(t.Size)
@@ -1653,13 +1717,13 @@ type Rgetxattr struct {
Value string
}
-// Decode implements encoder.Decode.
-func (r *Rgetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rgetxattr) decode(b *buffer) {
r.Value = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rgetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rgetxattr) encode(b *buffer) {
b.WriteString(r.Value)
}
@@ -1688,16 +1752,16 @@ type Tsetxattr struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Tsetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsetxattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.Value = b.ReadString()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tsetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsetxattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.WriteString(t.Value)
@@ -1718,12 +1782,12 @@ func (t *Tsetxattr) String() string {
type Rsetxattr struct {
}
-// Decode implements encoder.Decode.
-func (r *Rsetxattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rsetxattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rsetxattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rsetxattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1733,7 +1797,60 @@ func (*Rsetxattr) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rsetxattr) String() string {
- return fmt.Sprintf("Rsetxattr{}")
+ return "Rsetxattr{}"
+}
+
+// Tremovexattr is a removexattr request.
+type Tremovexattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tremovexattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tremovexattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tremovexattr) Type() MsgType {
+ return MsgTremovexattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tremovexattr) String() string {
+ return fmt.Sprintf("Tremovexattr{FID: %d, Name: %s}", t.FID, t.Name)
+}
+
+// Rremovexattr is a removexattr response.
+type Rremovexattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rremovexattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rremovexattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremovexattr) Type() MsgType {
+ return MsgRremovexattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rremovexattr) String() string {
+ return "Rremovexattr{}"
}
// Treaddir is a readdir request.
@@ -1748,15 +1865,15 @@ type Treaddir struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Treaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treaddir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Treaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treaddir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1790,14 +1907,14 @@ type Rreaddir struct {
payload []byte
}
-// Decode implements encoder.Decode.
-func (r *Rreaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreaddir) decode(b *buffer) {
r.Count = b.Read32()
entriesBuf := buffer{data: r.payload}
r.Entries = r.Entries[:0]
for {
var d Dirent
- d.Decode(&entriesBuf)
+ d.decode(&entriesBuf)
if entriesBuf.isOverrun() {
// Couldn't decode a complete entry.
break
@@ -1806,11 +1923,11 @@ func (r *Rreaddir) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (r *Rreaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreaddir) encode(b *buffer) {
entriesBuf := buffer{}
for _, d := range r.Entries {
- d.Encode(&entriesBuf)
+ d.encode(&entriesBuf)
if len(entriesBuf.data) >= int(r.Count) {
break
}
@@ -1855,13 +1972,13 @@ type Tfsync struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tfsync) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tfsync) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tfsync) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tfsync) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1879,12 +1996,12 @@ func (t *Tfsync) String() string {
type Rfsync struct {
}
-// Decode implements encoder.Decode.
-func (*Rfsync) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rfsync) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rfsync) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rfsync) encode(*buffer) {
}
// Type implements message.Type.
@@ -1894,7 +2011,7 @@ func (*Rfsync) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rfsync) String() string {
- return fmt.Sprintf("Rfsync{}")
+ return "Rfsync{}"
}
// Tstatfs is a stat request.
@@ -1903,13 +2020,13 @@ type Tstatfs struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tstatfs) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tstatfs) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tstatfs) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tstatfs) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1929,14 +2046,14 @@ type Rstatfs struct {
FSStat FSStat
}
-// Decode implements encoder.Decode.
-func (r *Rstatfs) Decode(b *buffer) {
- r.FSStat.Decode(b)
+// decode implements encoder.decode.
+func (r *Rstatfs) decode(b *buffer) {
+ r.FSStat.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rstatfs) Encode(b *buffer) {
- r.FSStat.Encode(b)
+// encode implements encoder.encode.
+func (r *Rstatfs) encode(b *buffer) {
+ r.FSStat.encode(b)
}
// Type implements message.Type.
@@ -1955,13 +2072,13 @@ type Tflushf struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tflushf) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflushf) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tflushf) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflushf) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1979,12 +2096,12 @@ func (t *Tflushf) String() string {
type Rflushf struct {
}
-// Decode implements encoder.Decode.
-func (*Rflushf) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rflushf) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflushf) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rflushf) encode(*buffer) {
}
// Type implements message.Type.
@@ -1994,7 +2111,7 @@ func (*Rflushf) Type() MsgType {
// String implements fmt.Stringer.
func (*Rflushf) String() string {
- return fmt.Sprintf("Rflushf{}")
+ return "Rflushf{}"
}
// Twalkgetattr is a walk request.
@@ -2009,8 +2126,8 @@ type Twalkgetattr struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalkgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalkgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -2020,8 +2137,8 @@ func (t *Twalkgetattr) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalkgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalkgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -2052,26 +2169,26 @@ type Rwalkgetattr struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalkgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rwalkgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.Attr.decode(b)
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalkgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rwalkgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.Attr.encode(b)
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -2093,15 +2210,15 @@ type Tucreate struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tucreate) Decode(b *buffer) {
- t.Tlcreate.Decode(b)
+// decode implements encoder.decode.
+func (t *Tucreate) decode(b *buffer) {
+ t.Tlcreate.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tucreate) Encode(b *buffer) {
- t.Tlcreate.Encode(b)
+// encode implements encoder.encode.
+func (t *Tucreate) encode(b *buffer) {
+ t.Tlcreate.encode(b)
b.WriteUID(t.UID)
}
@@ -2138,15 +2255,15 @@ type Tumkdir struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumkdir) Decode(b *buffer) {
- t.Tmkdir.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumkdir) decode(b *buffer) {
+ t.Tmkdir.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumkdir) Encode(b *buffer) {
- t.Tmkdir.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumkdir) encode(b *buffer) {
+ t.Tmkdir.encode(b)
b.WriteUID(t.UID)
}
@@ -2183,15 +2300,15 @@ type Tumknod struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumknod) Decode(b *buffer) {
- t.Tmknod.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumknod) decode(b *buffer) {
+ t.Tmknod.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumknod) Encode(b *buffer) {
- t.Tmknod.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumknod) encode(b *buffer) {
+ t.Tmknod.encode(b)
b.WriteUID(t.UID)
}
@@ -2228,15 +2345,15 @@ type Tusymlink struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tusymlink) Decode(b *buffer) {
- t.Tsymlink.Decode(b)
+// decode implements encoder.decode.
+func (t *Tusymlink) decode(b *buffer) {
+ t.Tsymlink.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tusymlink) Encode(b *buffer) {
- t.Tsymlink.Encode(b)
+// encode implements encoder.encode.
+func (t *Tusymlink) encode(b *buffer) {
+ t.Tsymlink.encode(b)
b.WriteUID(t.UID)
}
@@ -2274,14 +2391,14 @@ type Tlconnect struct {
Flags ConnectFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlconnect) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlconnect) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadConnectFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlconnect) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlconnect) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteConnectFlags(t.Flags)
}
@@ -2301,11 +2418,11 @@ type Rlconnect struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlconnect) Decode(*buffer) {}
+// decode implements encoder.decode.
+func (r *Rlconnect) decode(*buffer) {}
-// Encode implements encoder.Encode.
-func (r *Rlconnect) Encode(*buffer) {}
+// encode implements encoder.encode.
+func (r *Rlconnect) encode(*buffer) {}
// Type implements message.Type.
func (*Rlconnect) Type() MsgType {
@@ -2328,14 +2445,14 @@ type Tchannel struct {
Control uint32
}
-// Decode implements encoder.Decode.
-func (t *Tchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tchannel) decode(b *buffer) {
t.ID = b.Read32()
t.Control = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tchannel) encode(b *buffer) {
b.Write32(t.ID)
b.Write32(t.Control)
}
@@ -2357,14 +2474,14 @@ type Rchannel struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rchannel) decode(b *buffer) {
r.Offset = b.Read64()
r.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rchannel) encode(b *buffer) {
b.Write64(r.Offset)
b.Write64(r.Length)
}
@@ -2460,7 +2577,7 @@ func calculateSize(m message) uint32 {
return p.FixedSize()
}
var dataBuf buffer
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
return uint32(len(dataBuf.data))
}
@@ -2484,6 +2601,8 @@ func init() {
msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTlistxattr, func() message { return &Tlistxattr{} })
+ msgRegistry.register(MsgRlistxattr, func() message { return &Rlistxattr{} })
msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
@@ -2492,6 +2611,8 @@ func init() {
msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
+ msgRegistry.register(MsgTremovexattr, func() message { return &Tremovexattr{} })
+ msgRegistry.register(MsgRremovexattr, func() message { return &Rremovexattr{} })
msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 825c939da..c20324404 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -382,7 +382,7 @@ func TestEncodeDecode(t *testing.T) {
// Encode the original.
data := make([]byte, initialBufferLength)
buf := buffer{data: data[:0]}
- enc.Encode(&buf)
+ enc.encode(&buf)
// Create a new object, same as the first.
enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder)
@@ -399,7 +399,7 @@ func TestEncodeDecode(t *testing.T) {
}
// Mark sure it was okay.
- enc2.Decode(&buf2)
+ enc2.decode(&buf2)
if buf2.isOverrun() {
t.Errorf("object %#v->%#v got overrun on decode", enc, enc2)
continue
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 5ab00d625..28d851ff5 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -335,6 +335,8 @@ const (
MsgRgetattr = 25
MsgTsetattr = 26
MsgRsetattr = 27
+ MsgTlistxattr = 28
+ MsgRlistxattr = 29
MsgTxattrwalk = 30
MsgRxattrwalk = 31
MsgTxattrcreate = 32
@@ -343,6 +345,8 @@ const (
MsgRgetxattr = 35
MsgTsetxattr = 36
MsgRsetxattr = 37
+ MsgTremovexattr = 38
+ MsgRremovexattr = 39
MsgTreaddir = 40
MsgRreaddir = 41
MsgTfsync = 50
@@ -446,15 +450,15 @@ func (q QID) String() string {
return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
}
-// Decode implements encoder.Decode.
-func (q *QID) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (q *QID) decode(b *buffer) {
q.Type = b.ReadQIDType()
q.Version = b.Read32()
q.Path = b.Read64()
}
-// Encode implements encoder.Encode.
-func (q *QID) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (q *QID) encode(b *buffer) {
b.WriteQIDType(q.Type)
b.Write32(q.Version)
b.Write64(q.Path)
@@ -511,8 +515,8 @@ type FSStat struct {
NameLength uint32
}
-// Decode implements encoder.Decode.
-func (f *FSStat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (f *FSStat) decode(b *buffer) {
f.Type = b.Read32()
f.BlockSize = b.Read32()
f.Blocks = b.Read64()
@@ -524,8 +528,8 @@ func (f *FSStat) Decode(b *buffer) {
f.NameLength = b.Read32()
}
-// Encode implements encoder.Encode.
-func (f *FSStat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (f *FSStat) encode(b *buffer) {
b.Write32(f.Type)
b.Write32(f.BlockSize)
b.Write64(f.Blocks)
@@ -675,8 +679,8 @@ func (a AttrMask) String() string {
return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
}
-// Decode implements encoder.Decode.
-func (a *AttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AttrMask) decode(b *buffer) {
mask := b.Read64()
a.Mode = mask&0x00000001 != 0
a.NLink = mask&0x00000002 != 0
@@ -694,8 +698,8 @@ func (a *AttrMask) Decode(b *buffer) {
a.DataVersion = mask&0x00002000 != 0
}
-// Encode implements encoder.Encode.
-func (a *AttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AttrMask) encode(b *buffer) {
var mask uint64
if a.Mode {
mask |= 0x00000001
@@ -770,8 +774,8 @@ func (a Attr) String() string {
a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
}
-// Encode implements encoder.Encode.
-func (a *Attr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *Attr) encode(b *buffer) {
b.WriteFileMode(a.Mode)
b.WriteUID(a.UID)
b.WriteGID(a.GID)
@@ -792,8 +796,8 @@ func (a *Attr) Encode(b *buffer) {
b.Write64(a.DataVersion)
}
-// Decode implements encoder.Decode.
-func (a *Attr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *Attr) decode(b *buffer) {
a.Mode = b.ReadFileMode()
a.UID = b.ReadUID()
a.GID = b.ReadGID()
@@ -922,8 +926,8 @@ func (s SetAttrMask) Empty() bool {
return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
}
-// Decode implements encoder.Decode.
-func (s *SetAttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttrMask) decode(b *buffer) {
mask := b.Read32()
s.Permissions = mask&0x00000001 != 0
s.UID = mask&0x00000002 != 0
@@ -968,8 +972,8 @@ func (s SetAttrMask) bitmask() uint32 {
return mask
}
-// Encode implements encoder.Encode.
-func (s *SetAttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttrMask) encode(b *buffer) {
b.Write32(s.bitmask())
}
@@ -990,8 +994,8 @@ func (s SetAttr) String() string {
return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
}
-// Decode implements encoder.Decode.
-func (s *SetAttr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttr) decode(b *buffer) {
s.Permissions = b.ReadPermissions()
s.UID = b.ReadUID()
s.GID = b.ReadGID()
@@ -1002,8 +1006,8 @@ func (s *SetAttr) Decode(b *buffer) {
s.MTimeNanoSeconds = b.Read64()
}
-// Encode implements encoder.Encode.
-func (s *SetAttr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttr) encode(b *buffer) {
b.WritePermissions(s.Permissions)
b.WriteUID(s.UID)
b.WriteGID(s.GID)
@@ -1060,17 +1064,17 @@ func (d Dirent) String() string {
return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
}
-// Decode implements encoder.Decode.
-func (d *Dirent) Decode(b *buffer) {
- d.QID.Decode(b)
+// decode implements encoder.decode.
+func (d *Dirent) decode(b *buffer) {
+ d.QID.decode(b)
d.Offset = b.Read64()
d.Type = b.ReadQIDType()
d.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (d *Dirent) Encode(b *buffer) {
- d.QID.Encode(b)
+// encode implements encoder.encode.
+func (d *Dirent) encode(b *buffer) {
+ d.QID.encode(b)
b.Write64(d.Offset)
b.WriteQIDType(d.Type)
b.WriteString(d.Name)
@@ -1114,8 +1118,8 @@ func (a *AllocateMode) ToLinux() uint32 {
return rv
}
-// Decode implements encoder.Decode.
-func (a *AllocateMode) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AllocateMode) decode(b *buffer) {
mask := b.Read32()
a.KeepSize = mask&0x01 != 0
a.PunchHole = mask&0x02 != 0
@@ -1126,8 +1130,8 @@ func (a *AllocateMode) Decode(b *buffer) {
a.Unshare = mask&0x40 != 0
}
-// Encode implements encoder.Encode.
-func (a *AllocateMode) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AllocateMode) encode(b *buffer) {
mask := uint32(0)
if a.KeepSize {
mask |= 0x01
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 9c11e28ce..7cec0e86d 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -80,7 +80,7 @@ func send(s *unet.Socket, tag Tag, m message) error {
}
// Encode the message. The buffer will grow automatically.
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
// Get our vectors to send.
var hdr [headerLength]byte
@@ -316,7 +316,7 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
}
// Decode the message data.
- m.Decode(&dataBuf)
+ m.decode(&dataBuf)
if dataBuf.isOverrun() {
// No need to drain the socket.
return NoTag, nil, ErrNoValidMessage
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index 233f825e3..a0d274f3b 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -151,7 +151,7 @@ func (ch *channel) send(m message) (uint32, error) {
} else {
ch.buf.Write8(0) // No incoming FD.
}
- m.Encode(&ch.buf)
+ m.encode(&ch.buf)
ssz := uint32(len(ch.buf.data)) // Updated below.
// Is there a payload?
@@ -205,7 +205,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
ch.buf.data = ch.buf.data[:fs]
}
- r.Decode(&ch.buf)
+ r.decode(&ch.buf)
if ch.buf.isOverrun() {
// Nothing valid was available.
log.Debugf("recv [got %d bytes, needed more]", rsz)
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index 2f50ff3ea..3668fcad7 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -56,8 +56,8 @@ func TestSendRecv(t *testing.T) {
// badDecode overruns on decode.
type badDecode struct{}
-func (*badDecode) Decode(b *buffer) { b.markOverrun() }
-func (*badDecode) Encode(b *buffer) {}
+func (*badDecode) decode(b *buffer) { b.markOverrun() }
+func (*badDecode) encode(b *buffer) {}
func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
func (*badDecode) String() string { return "badDecode{}" }
@@ -81,8 +81,8 @@ func TestRecvOverrun(t *testing.T) {
// unregistered is not registered on decode.
type unregistered struct{}
-func (*unregistered) Decode(b *buffer) {}
-func (*unregistered) Encode(b *buffer) {}
+func (*unregistered) decode(b *buffer) {}
+func (*unregistered) encode(b *buffer) {}
func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
func (*unregistered) String() string { return "unregistered{}" }
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 34a15eb55..09cde9f5a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 10
+ highestSupportedVersion uint32 = 11
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -167,3 +167,9 @@ func VersionSupportsOpenTruncateFlag(v uint32) bool {
func versionSupportsGetSetXattr(v uint32) bool {
return v >= 10
}
+
+// versionSupportsListRemoveXattr returns true if version v supports
+// the Tlistxattr and Tremovexattr messages.
+func versionSupportsListRemoveXattr(v uint32) bool {
+ return v >= 11
+}
diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD
new file mode 100644
index 000000000..7b1c6b75b
--- /dev/null
+++ b/pkg/pool/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "pool",
+ srcs = [
+ "pool.go",
+ ],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "pool_test",
+ size = "small",
+ srcs = [
+ "pool_test.go",
+ ],
+ library = ":pool",
+)
diff --git a/pkg/p9/pool.go b/pkg/pool/pool.go
index 2b14a5ce3..a1b2e0cfe 100644
--- a/pkg/p9/pool.go
+++ b/pkg/pool/pool.go
@@ -12,33 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"gvisor.dev/gvisor/pkg/sync"
)
-// pool is a simple allocator.
-//
-// It is used for both tags and FIDs.
-type pool struct {
+// Pool is a simple allocator.
+type Pool struct {
mu sync.Mutex
// cache is the set of returned values.
cache []uint64
- // start is the starting value (if needed).
- start uint64
+ // Start is the starting value (if needed).
+ Start uint64
// max is the current maximum issued.
max uint64
- // limit is the upper limit.
- limit uint64
+ // Limit is the upper limit.
+ Limit uint64
}
// Get gets a value from the pool.
-func (p *pool) Get() (uint64, bool) {
+func (p *Pool) Get() (uint64, bool) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -50,18 +48,18 @@ func (p *pool) Get() (uint64, bool) {
}
// Over the limit?
- if p.start == p.limit {
+ if p.Start == p.Limit {
return 0, false
}
// Generate a new value.
- v := p.start
- p.start++
+ v := p.Start
+ p.Start++
return v, true
}
// Put returns a value to the pool.
-func (p *pool) Put(v uint64) {
+func (p *Pool) Put(v uint64) {
p.mu.Lock()
p.cache = append(p.cache, v)
p.mu.Unlock()
diff --git a/pkg/p9/pool_test.go b/pkg/pool/pool_test.go
index e4746b8da..d928439c1 100644
--- a/pkg/p9/pool_test.go
+++ b/pkg/pool/pool_test.go
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"testing"
)
func TestPoolUnique(t *testing.T) {
- p := pool{start: 1, limit: 3}
+ p := Pool{Start: 1, Limit: 3}
got := make(map[uint64]bool)
for {
@@ -39,7 +39,7 @@ func TestPoolUnique(t *testing.T) {
}
func TestExausted(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
for i := 0; i < 499; i++ {
_, ok := p.Get()
if !ok {
@@ -54,7 +54,7 @@ func TestExausted(t *testing.T) {
}
func TestPoolRecycle(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
n1, _ := p.Get()
p.Put(n1)
n2, _ := p.Get()
diff --git a/pkg/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
index 5818f7f9b..7f7f69d61 100644
--- a/pkg/safecopy/safecopy_test.go
+++ b/pkg/safecopy/safecopy_test.go
@@ -138,10 +138,14 @@ func TestSwapUint32Success(t *testing.T) {
func TestSwapUint32AlignmentError(t *testing.T) {
// Test that SwapUint32 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := SwapUint32(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -171,10 +175,14 @@ func TestSwapUint64Success(t *testing.T) {
func TestSwapUint64AlignmentError(t *testing.T) {
// Test that SwapUint64 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val1, val2 uint64 })
- addr := uintptr(unsafe.Pointer(&data.val1)) + 1
- want := AlignmentError{Addr: addr, Alignment: 8}
- if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 16) // 2 * sizeof(uint64).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 {
+ alignedIndex = 8 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 8}
+ if _, err := SwapUint64(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -201,10 +209,14 @@ func TestCompareAndSwapUint32Success(t *testing.T) {
func TestCompareAndSwapUint32AlignmentError(t *testing.T) {
// Test that CompareAndSwapUint32 returns an AlignmentError when passed an
// unaligned address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -252,8 +264,8 @@ func TestCopyInSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -276,8 +288,8 @@ func TestCopyInBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -300,8 +312,8 @@ func TestCopyOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -324,8 +336,8 @@ func TestCopyOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -348,8 +360,8 @@ func TestCopySourceSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -372,8 +384,8 @@ func TestCopySourceBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -396,8 +408,8 @@ func TestCopyDestinationSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -420,8 +432,8 @@ func TestCopyDestinationBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -444,8 +456,8 @@ func TestZeroOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -467,8 +479,8 @@ func TestZeroOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -488,7 +500,7 @@ func TestSwapUint32SegvError(t *testing.T) {
// Test that SwapUint32 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -500,7 +512,7 @@ func TestSwapUint32BusError(t *testing.T) {
// Test that SwapUint32 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -512,7 +524,7 @@ func TestSwapUint64SegvError(t *testing.T) {
// Test that SwapUint64 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -524,7 +536,7 @@ func TestSwapUint64BusError(t *testing.T) {
// Test that SwapUint64 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -536,7 +548,7 @@ func TestCompareAndSwapUint32SegvError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a SegvError when reaching a page
// that signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -548,7 +560,7 @@ func TestCompareAndSwapUint32BusError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a BusError when reaching a page
// that signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index eef028e68..41dd567f3 100644
--- a/pkg/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -16,6 +16,7 @@ package safecopy
import (
"fmt"
+ "runtime"
"syscall"
"unsafe"
)
@@ -35,7 +36,7 @@ const maxRegisterSize = 16
// successfully copied.
//
//go:noescape
-func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
// signal is received during the write, it returns the address that caused the
@@ -47,7 +48,7 @@ func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32
// successfully written.
//
//go:noescape
-func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
@@ -90,29 +91,35 @@ func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ n, err := copyIn(dst, uintptr(src))
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyIn is the underlying definition for CopyIn.
+func copyIn(dst []byte, src uintptr) (int, error) {
toCopy := uintptr(len(dst))
if len(dst) == 0 {
return 0, nil
}
- fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy)
+ fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy)
if sig == 0 {
return len(dst), nil
}
- faultN, srcN := uintptr(fault), uintptr(src)
- if faultN < srcN || faultN >= srcN+toCopy {
- panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy))
+ if fault < src || fault >= src+toCopy {
+ panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-srcN > maxRegisterSize {
- done = int(faultN - srcN - maxRegisterSize)
+ if fault-src > maxRegisterSize {
+ done = int(fault - src - maxRegisterSize)
}
- n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done)))
+ n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done))
done += n
if err != nil {
return done, err
@@ -124,29 +131,35 @@ func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
// bytes done and an error if SIGSEGV or SIGBUS is received while writing to
// dst.
func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ n, err := copyOut(uintptr(dst), src)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// copyOut is the underlying definition for CopyOut.
+func copyOut(dst uintptr, src []byte) (int, error) {
toCopy := uintptr(len(src))
if toCopy == 0 {
return 0, nil
}
- fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy)
+ fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy)
if sig == 0 {
return len(src), nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toCopy {
- panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy))
+ if fault < dst || fault >= dst+toCopy {
+ panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-dstN > maxRegisterSize {
- done = int(faultN - dstN - maxRegisterSize)
+ if fault-dst > maxRegisterSize {
+ done = int(fault - dst - maxRegisterSize)
}
- n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)])
+ n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)])
done += n
if err != nil {
return done, err
@@ -161,6 +174,14 @@ func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
// the resulting contents of dst are unspecified.
func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ n, err := copyN(uintptr(dst), uintptr(src), toCopy)
+ runtime.KeepAlive(dst)
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyN is the underlying definition for Copy.
+func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) {
if toCopy == 0 {
return 0, nil
}
@@ -171,17 +192,16 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
}
// Did the fault occur while reading from src or writing to dst?
- faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst)
faultAfterSrc := ^uintptr(0)
- if faultN >= srcN {
- faultAfterSrc = faultN - srcN
+ if fault >= src {
+ faultAfterSrc = fault - src
}
faultAfterDst := ^uintptr(0)
- if faultN >= dstN {
- faultAfterDst = faultN - dstN
+ if fault >= dst {
+ faultAfterDst = fault - dst
}
if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
- panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy))
+ panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy))
}
faultedAfter := faultAfterSrc
if faultedAfter > faultAfterDst {
@@ -195,7 +215,7 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
if faultedAfter > maxRegisterSize {
done = faultedAfter - maxRegisterSize
}
- n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done)
+ n, err := copyN(dst+done, src+done, faultedAfter-done)
done += n
if err != nil {
return done, err
@@ -206,6 +226,13 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ n, err := zeroOut(uintptr(dst), toZero)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// zeroOut is the underlying definition for ZeroOut.
+func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) {
if toZero == 0 {
return 0, nil
}
@@ -215,19 +242,18 @@ func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
return toZero, nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toZero {
- panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero))
+ if fault < dst || fault >= dst+toZero {
+ panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero))
}
// memclr might have ended the write up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to write up to the fault.
var done uintptr
- if faultN-dstN > maxRegisterSize {
- done = faultN - dstN - maxRegisterSize
+ if fault-dst > maxRegisterSize {
+ done = fault - dst - maxRegisterSize
}
- n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done)
+ n, err := zeroOut(dst+done, fault-dst-done)
done += n
if err != nil {
return done, err
@@ -243,7 +269,7 @@ func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
old, sig := swapUint32(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
@@ -254,7 +280,7 @@ func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
return 0, AlignmentError{addr, 8}
}
old, sig := swapUint64(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
@@ -265,7 +291,7 @@ func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
prev, sig := compareAndSwapUint32(ptr, old, new)
- return prev, errorFromFaultSignal(ptr, sig)
+ return prev, errorFromFaultSignal(uintptr(ptr), sig)
}
// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
@@ -277,17 +303,17 @@ func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
val, sig := loadUint32(ptr)
- return val, errorFromFaultSignal(ptr, sig)
+ return val, errorFromFaultSignal(uintptr(ptr), sig)
}
-func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error {
+func errorFromFaultSignal(addr uintptr, sig int32) error {
switch sig {
case 0:
return nil
case int32(syscall.SIGSEGV):
- return SegvError{uintptr(addr)}
+ return SegvError{addr}
case int32(syscall.SIGBUS):
- return BusError{uintptr(addr)}
+ return BusError{addr}
default:
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}
diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
index 354a95dde..dcdfc9600 100644
--- a/pkg/safemem/seq_unsafe.go
+++ b/pkg/safemem/seq_unsafe.go
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"reflect"
+ "syscall"
"unsafe"
)
@@ -297,3 +298,19 @@ func ZeroSeq(dsts BlockSeq) (uint64, error) {
}
return done, nil
}
+
+// IovecsFromBlockSeq returns a []syscall.Iovec representing seq.
+func IovecsFromBlockSeq(bs BlockSeq) []syscall.Iovec {
+ iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
+ for ; !bs.IsEmpty(); bs = bs.Tail() {
+ b := bs.Head()
+ iovs = append(iovs, syscall.Iovec{
+ Base: &b.ToSlice()[0],
+ Len: uint64(b.Len()),
+ })
+ // We don't need to care about b.NeedSafecopy(), because the host
+ // kernel will handle such address ranges just fine (by returning
+ // EFAULT).
+ }
+ return iovs
+}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index 742c8b79b..c5fca2ba3 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -26,7 +26,7 @@ go_library(
"seccomp_rules.go",
"seccomp_unsafe.go",
],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"//pkg/bpf",
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index fc36efa23..55fd6967e 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -219,24 +219,36 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc
switch a := arg.(type) {
case AllowAny:
case AllowValue:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
high, low := uint32(a>>32), uint32(a)
// assert arg_low == low
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
// assert arg_high == high
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
labelled = true
case GreaterThan:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
labelGood := fmt.Sprintf("gt%v", i)
high, low := uint32(a>>32), uint32(a)
// assert arg_high < high
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
// arg_high > high
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
// arg_low < low
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
labelled = true
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index 84c841d7f..06308cd29 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -62,7 +62,11 @@ func (a AllowValue) String() (s string) {
// rule := Rule {
// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
// }
-type Rule [6]interface{}
+type Rule [7]interface{} // 6 arguments + RIP
+
+// RuleIP indicates what rules in the Rule array have to be applied to
+// instruction pointer.
+const RuleIP = 6
func (r Rule) String() (s string) {
if len(r) == 0 {
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index abbee7051..da5a5e4b2 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -388,6 +388,33 @@ func TestBasic(t *testing.T) {
},
},
},
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ RuleIP: AllowValue(0x7aabbccdd),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "IP: Syscall instruction pointer allowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "IP: Syscall instruction pointer disallowed",
+ data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{}, instructionPointer: 0x711223344},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
} {
instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
if err != nil {
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index e8b794179..e759dc36f 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -1,13 +1,11 @@
-# This BUILD file defines a package_group that allows for interdependencies for
-# sentry-internal packages.
-
package(licenses = ["notice"])
+# The "internal" package_group should be used as much as possible by packages
+# that should remain Sentry-internal (i.e. not be exposed directly to command
+# line tooling or APIs).
package_group(
name = "internal",
packages = [
- "//cloud/gvisor/gopkg/sentry/...",
- "//cloud/gvisor/sentry/...",
"//pkg/sentry/...",
"//runsc/...",
# Code generated by go_marshal relies on go_marshal libraries.
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 34c0a867d..e27f21e5e 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -14,6 +14,7 @@ go_library(
"arch_state_aarch64.go",
"arch_state_x86.go",
"arch_x86.go",
+ "arch_x86_impl.go",
"auxv.go",
"signal.go",
"signal_act.go",
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 53039465d..5053393c1 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -40,6 +40,9 @@ const (
fpsimdContextSize = 0x210
)
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/arch_amd64.s
index bd61402cf..6c10336e7 100644
--- a/pkg/sentry/arch/arch_amd64.s
+++ b/pkg/sentry/arch/arch_amd64.s
@@ -26,10 +26,11 @@
//
// func initX86FPState(data *FloatingPointData, useXsave bool)
//
-// We need to clear out and initialize an empty fp state area since the sentry
-// may have left sensitive information in the floating point registers.
+// We need to clear out and initialize an empty fp state area since the sentry,
+// or any previous loader, may have left sensitive information in the floating
+// point registers.
//
-// Preconditions: data is zeroed
+// Preconditions: data is zeroed.
TEXT ·initX86FPState(SB), $24-16
// Save MXCSR (callee-save)
STMXCSR mxcsr-8(SP)
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index d388ee9cf..e35c9214a 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -43,8 +43,8 @@ func (e ErrFloatingPoint) Error() string {
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
-// afterLoad is invoked by stateify.
-func (s *State) afterLoad() {
+// afterLoadFPState is invoked by afterLoad.
+func (s *State) afterLoadFPState() {
old := s.x86FPState
// Recreate the slice. This is done to ensure that it is aligned
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index a18093155..88b40a9d1 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -114,6 +114,10 @@ func newX86FPStateSlice() []byte {
size, align := cpuid.HostFeatureSet().ExtendedStateSize()
capacity := size
// Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
if capacity < 4096 {
capacity = 4096
}
@@ -151,21 +155,6 @@ func NewFloatingPointData() *FloatingPointData {
return (*FloatingPointData)(&(newX86FPState()[0]))
}
-// State contains the common architecture bits for X86 (the build tag of this
-// file ensures it's only built on x86).
-//
-// +stateify savable
-type State struct {
- // The system registers.
- Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
-
- // Our floating point state.
- x86FPState `state:"wait"`
-
- // FeatureSet is a pointer to the currently active feature set.
- FeatureSet *cpuid.FeatureSet
-}
-
// Proto returns a protobuf representation of the system registers in State.
func (s State) Proto() *rpb.Registers {
regs := &rpb.AMD64Registers{
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
new file mode 100644
index 000000000..04ac283c6
--- /dev/null
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -0,0 +1,43 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 i386
+
+package arch
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+)
+
+// State contains the common architecture bits for X86 (the build tag of this
+// file ensures it's only built on x86).
+//
+// +stateify savable
+type State struct {
+ // The system registers.
+ Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+
+ // Our floating point state.
+ x86FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// afterLoad is invoked by stateify.
+func (s *State) afterLoad() {
+ s.afterLoadFPState()
+}
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index 81b92bb43..6fb756f0e 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -55,7 +55,7 @@ type SignalContext64 struct {
Trapno uint64
Oldmask linux.SignalSet
Cr2 uint64
- // Pointer to a struct _fpstate.
+ // Pointer to a struct _fpstate. See b/33003106#comment8.
Fpstate uint64
Reserved [8]uint64
}
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
index 28615f97f..0c1db4b13 100644
--- a/pkg/sentry/arch/signal_arm64.go
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -40,6 +40,8 @@ type aarch64Ctx struct {
Size uint32
}
+// FpsimdContext is equivalent to struct fpsimd_context on arm64
+// (arch/arm64/include/uapi/asm/sigcontext.h).
type FpsimdContext struct {
Head aarch64Ctx
Fpsr uint32
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index e69496477..d16d78aa5 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -16,10 +16,13 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
@@ -27,8 +30,10 @@ go_library(
"//pkg/sentry/state",
"//pkg/sentry/strace",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
],
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index ced51c66c..5457ba5e7 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -18,19 +18,26 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "path"
"sort"
"strings"
"text/tabwriter"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -60,6 +67,12 @@ type ExecArgs struct {
// process's MountNamespace.
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 is the mount namespace to execute the new process in.
+ // A reference on MountNamespace must be held for the lifetime of the
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// WorkingDirectory defines the working directory for the new process.
WorkingDirectory string `json:"wd"`
@@ -150,6 +163,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Envv: args.Envv,
WorkingDirectory: args.WorkingDirectory,
MountNamespace: args.MountNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
@@ -166,24 +180,53 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
// be donated to the new process in CreateProcess.
initArgs.MountNamespace.IncRef()
}
+ if initArgs.MountNamespaceVFS2 != nil {
+ // initArgs must hold a reference on MountNamespaceVFS2, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
ctx := initArgs.NewContext(proc.Kernel)
if initArgs.Filename == "" {
- // Get the full path to the filename from the PATH env variable.
- paths := fs.GetPath(initArgs.Envv)
- mns := initArgs.MountNamespace
- if mns == nil {
- mns = proc.Kernel.GlobalInit().Leader().MountNamespace()
- }
- f, err := mns.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
- if err != nil {
- return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ if kernel.VFS2Enabled {
+ // Get the full path to the filename from the PATH env variable.
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
+ }
+
+ paths := fs.GetPath(initArgs.Envv)
+ vfsObj := proc.Kernel.VFS()
+ file, err := ResolveExecutablePath(ctx, vfsObj, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ }
+ initArgs.File = fsbridge.NewVFSFile(file)
+ } else {
+ // Get the full path to the filename from the PATH env variable.
+ paths := fs.GetPath(initArgs.Envv)
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
+
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
+ f, err := initArgs.MountNamespace.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
+ if err != nil {
+ return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ }
+ initArgs.Filename = f
}
- initArgs.Filename = f
}
mounter := fs.FileOwnerFromContext(ctx)
+ // TODO(gvisor.dev/issue/1623): Use host FD when supported in VFS2.
var ttyFile *fs.File
for appFD, hostFile := range args.FilePayload.Files {
var appFile *fs.File
@@ -411,3 +454,67 @@ func ttyName(tty *kernel.TTY) string {
}
return fmt.Sprintf("pts/%d", tty.Index)
}
+
+// ResolveExecutablePath resolves the given executable name given a set of
+// paths that might contain it.
+func ResolveExecutablePath(ctx context.Context, vfsObj *vfs.VirtualFilesystem, wd, name string, paths []string) (*vfs.FileDescription, error) {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+ creds := auth.CredentialsFromContext(ctx)
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return openExecutable(ctx, vfsObj, creds, root, name)
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ if len(wd) == 0 {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return nil, fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return openExecutable(ctx, vfsObj, creds, root, path.Join(wd, name))
+ }
+
+ // Otherwise, we must lookup the name in the paths, starting from the
+ // calling context's root directory.
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ f, err := openExecutable(ctx, vfsObj, creds, root, binPath)
+ if err != nil {
+ return nil, err
+ }
+ if f == nil {
+ continue // Not found/no access.
+ }
+ return f, nil
+ }
+ return nil, syserror.ENOENT
+}
+
+func openExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry, path string) (*vfs.FileDescription, error) {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root, // binPath is absolute, Start can be anything.
+ Path: fspath.Parse(path),
+ FollowFinalSymlink: true,
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ }
+ f, err := vfsObj.OpenAt(ctx, creds, &pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ return nil, nil
+ }
+ return f, err
+}
diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD
new file mode 100644
index 000000000..abe58f818
--- /dev/null
+++ b/pkg/sentry/devices/memdev/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "memdev",
+ srcs = [
+ "full.go",
+ "memdev.go",
+ "null.go",
+ "random.go",
+ "zero.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
new file mode 100644
index 000000000..c7e197691
--- /dev/null
+++ b/pkg/sentry/devices/memdev/full.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const fullDevMinor = 7
+
+// fullDevice implements vfs.Device for /dev/full.
+type fullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &fullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// fullFD implements vfs.FileDescriptionImpl for /dev/full.
+type fullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *fullFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
diff --git a/pkg/sentry/devices/memdev/memdev.go b/pkg/sentry/devices/memdev/memdev.go
new file mode 100644
index 000000000..5759900c4
--- /dev/null
+++ b/pkg/sentry/devices/memdev/memdev.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memdev implements "mem" character devices, as implemented in Linux
+// by drivers/char/mem.c and drivers/char/random.c.
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ for minor, dev := range map[uint32]vfs.Device{
+ nullDevMinor: nullDevice{},
+ zeroDevMinor: zeroDevice{},
+ fullDevMinor: fullDevice{},
+ randomDevMinor: randomDevice{},
+ urandomDevMinor: randomDevice{},
+ } {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MEM_MAJOR, minor, dev, &vfs.RegisterDeviceOptions{
+ GroupName: "mem",
+ }); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ for minor, name := range map[uint32]string{
+ nullDevMinor: "null",
+ zeroDevMinor: "zero",
+ fullDevMinor: "full",
+ randomDevMinor: "random",
+ urandomDevMinor: "urandom",
+ } {
+ if err := dev.CreateDeviceFile(ctx, name, vfs.CharDevice, linux.MEM_MAJOR, minor, 0666 /* mode */); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
new file mode 100644
index 000000000..33d060d02
--- /dev/null
+++ b/pkg/sentry/devices/memdev/null.go
@@ -0,0 +1,76 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const nullDevMinor = 3
+
+// nullDevice implements vfs.Device for /dev/null.
+type nullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &nullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// nullFD implements vfs.FileDescriptionImpl for /dev/null.
+type nullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nullFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
new file mode 100644
index 000000000..acfa23149
--- /dev/null
+++ b/pkg/sentry/devices/memdev/random.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ randomDevMinor = 8
+ urandomDevMinor = 9
+)
+
+// randomDevice implements vfs.Device for /dev/random and /dev/urandom.
+type randomDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &randomFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// randomFD implements vfs.FileDescriptionImpl for /dev/random.
+type randomFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ // off is the "file offset". off is accessed using atomic memory
+ // operations.
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *randomFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *randomFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *randomFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+ atomic.AddInt64(&fd.off, n)
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *randomFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // In Linux, this mixes the written bytes into the entropy pool; we just
+ // throw them away.
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *randomFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ atomic.AddInt64(&fd.off, src.NumBytes())
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *randomFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Linux: drivers/char/random.c:random_fops.llseek == urandom_fops.llseek
+ // == noop_llseek
+ return atomic.LoadInt64(&fd.off), nil
+}
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
new file mode 100644
index 000000000..3b1372b9e
--- /dev/null
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -0,0 +1,88 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const zeroDevMinor = 5
+
+// zeroDevice implements vfs.Device for /dev/zero.
+type zeroDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &zeroFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// zeroFD implements vfs.FileDescriptionImpl for /dev/zero.
+type zeroFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *zeroFD) Release() {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *zeroFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *zeroFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *zeroFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *zeroFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return err
+ }
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ return nil
+}
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index f6c79e51b..b060a12ff 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -401,7 +401,7 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error
if err != nil {
return err
}
- lowerXattr, err := lower.ListXattr(ctx)
+ lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX)
if err != nil && err != syserror.EOPNOTSUPP {
return err
}
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 4c4b7d5cc..9b6bb26d0 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -9,6 +9,7 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
@@ -19,15 +20,19 @@ go_library(
"//pkg/context",
"//pkg/rand",
"//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 35bd23991..7e66c29b0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+
+ "net": newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc),
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..755644488
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index acab0411a..e0b32e1c1 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1438,8 +1438,8 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName
}, nil
}
-func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
- uattr, err := dir.Inode.UnstableAttr(ctx)
+func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error {
+ uattr, err := d.Inode.UnstableAttr(ctx)
if err != nil {
return syserror.EPERM
}
@@ -1465,30 +1465,33 @@ func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
return syserror.EPERM
}
-// MayDelete determines whether `name`, a child of `dir`, can be deleted or
+// MayDelete determines whether `name`, a child of `d`, can be deleted or
// renamed by `ctx`.
//
// Compare Linux kernel fs/namei.c:may_delete.
-func MayDelete(ctx context.Context, root, dir *Dirent, name string) error {
- if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error {
+ if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
return err
}
- victim, err := dir.Walk(ctx, root, name)
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ victim, err := d.walk(ctx, root, name, true /* may unlock */)
if err != nil {
return err
}
defer victim.DecRef()
- return mayDelete(ctx, dir, victim)
+ return d.mayDelete(ctx, victim)
}
// mayDelete determines whether `victim`, a child of `dir`, can be deleted or
// renamed by `ctx`.
//
// Preconditions: `dir` is writable and executable by `ctx`.
-func mayDelete(ctx context.Context, dir, victim *Dirent) error {
- if err := checkSticky(ctx, dir, victim); err != nil {
+func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error {
+ if err := d.checkSticky(ctx, victim); err != nil {
return err
}
@@ -1542,7 +1545,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
defer renamed.DecRef()
// Check that the renamed dirent is deletable.
- if err := mayDelete(ctx, oldParent, renamed); err != nil {
+ if err := oldParent.mayDelete(ctx, renamed); err != nil {
return err
}
@@ -1580,7 +1583,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// across the Rename, so must call DecRef manually (no defer).
// Check that we can delete replaced.
- if err := mayDelete(ctx, newParent, replaced); err != nil {
+ if err := newParent.mayDelete(ctx, replaced); err != nil {
replaced.DecRef()
return err
}
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index 02538bb4f..a76d87e3a 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -177,6 +177,7 @@ func TestReaddirRevalidation(t *testing.T) {
// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
// a frozen dirent tree does not make Readdir calls to the underlying files.
+// This is a regression test for b/114808269.
func TestReaddirOverlayFrozen(t *testing.T) {
ctx := contexttest.Context(t)
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 4ab2a384f..789369220 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -28,13 +28,13 @@ go_template_instance(
"platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
- prefix = "frameRef",
+ prefix = "FrameRef",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
"Range": "platform.FileRange",
"Value": "uint64",
- "Functions": "frameRefSetFunctions",
+ "Functions": "FrameRefSetFunctions",
},
)
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index dd63db32b..6564fd0c6 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -20,24 +20,25 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
)
-type frameRefSetFunctions struct{}
+// FrameRefSetFunctions implements segment.Functions for FrameRefSet.
+type FrameRefSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (frameRefSetFunctions) MinKey() uint64 {
+func (FrameRefSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (frameRefSetFunctions) MaxKey() uint64 {
+func (FrameRefSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (frameRefSetFunctions) ClearValue(val *uint64) {
+func (FrameRefSetFunctions) ClearValue(val *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (frameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
if val1 != val2 {
return 0, false
}
@@ -45,6 +46,6 @@ func (frameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
}
// Split implements segment.Functions.Split.
-func (frameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 67278aa86..e82afd112 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -65,13 +65,18 @@ type mapping struct {
writable bool
}
-// NewHostFileMapper returns a HostFileMapper with no references or cached
-// mappings.
+// Init must be called on zero-value HostFileMappers before first use.
+func (f *HostFileMapper) Init() {
+ f.refs = make(map[uint64]int32)
+ f.mappings = make(map[uint64]mapping)
+}
+
+// NewHostFileMapper returns an initialized HostFileMapper allocated on the
+// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
- return &HostFileMapper{
- refs: make(map[uint64]int32),
- mappings: make(map[uint64]mapping),
- }
+ f := &HostFileMapper{}
+ f.Init()
+ return f
}
// IncRefOn increments the reference count on all offsets in mr.
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 252830572..daecc4ffe 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -247,7 +247,7 @@ func (i *InodeSimpleExtendedAttributes) SetXattr(_ context.Context, _ *fs.Inode,
}
// ListXattr implements fs.InodeOperations.ListXattr.
-func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
+func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
i.mu.RLock()
names := make(map[string]struct{}, len(i.xattrs))
for name := range i.xattrs {
@@ -257,6 +257,17 @@ func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode) (m
return names, nil
}
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
+ i.mu.RLock()
+ defer i.mu.RUnlock()
+ if _, ok := i.xattrs[name]; ok {
+ delete(i.xattrs, name)
+ return nil
+ }
+ return syserror.ENOATTR
+}
+
// staticFile is a file with static contents. It is returned by
// InodeStaticFileGetter.GetFile.
//
@@ -460,10 +471,15 @@ func (InodeNoExtendedAttributes) SetXattr(context.Context, *fs.Inode, string, st
}
// ListXattr implements fs.InodeOperations.ListXattr.
-func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
+func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
return nil, syserror.EOPNOTSUPP
}
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (InodeNoExtendedAttributes) RemoveXattr(context.Context, *fs.Inode, string) error {
+ return syserror.EOPNOTSUPP
+}
+
// InodeNoopRelease implements fs.InodeOperations.Release as a noop.
type InodeNoopRelease struct{}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 573b8586e..800c8b4e1 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -111,7 +111,7 @@ type CachingInodeOperations struct {
// refs tracks active references to data in the cache.
//
// refs is protected by dataMu.
- refs frameRefSet
+ refs FrameRefSet
}
// CachingInodeOperationsOptions configures a CachingInodeOperations.
diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md
index 71a577d9d..85063d4e6 100644
--- a/pkg/sentry/fs/g3doc/inotify.md
+++ b/pkg/sentry/fs/g3doc/inotify.md
@@ -112,11 +112,11 @@ attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used
`Inotify.mu` to also protect the event queue, this would violate the above lock
ordering.
-[dirent]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/dirent.go
-[event]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_event.go
-[fd_table]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/kernel/fd_table.go
-[inode]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode.go
-[inode_watches]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode_inotify.go
-[inotify]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify.go
-[syscall_dir]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/syscalls/linux/
-[watch]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_watch.go
+[dirent]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/dirent.go
+[event]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_event.go
+[fd_table]: https://github.com/google/gvisor/blob/master/pkg/sentry/kernel/fd_table.go
+[inode]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode.go
+[inode_watches]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode_inotify.go
+[inotify]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify.go
+[syscall_dir]: https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/
+[watch]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_watch.go
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 971d3718e..fea135eea 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -9,6 +9,7 @@ go_library(
"cache_policy.go",
"context_file.go",
"device.go",
+ "fifo.go",
"file.go",
"file_state.go",
"fs.go",
@@ -38,6 +39,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/host",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index 71cccdc34..6db4b762d 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -88,8 +88,9 @@ func bsize(pattr p9.Attr) int64 {
if pattr.BlockSize > 0 {
return int64(pattr.BlockSize)
}
- // Some files may have no clue of their block size. Better not to report
- // something misleading or buggy and have a safe default.
+ // Some files, particularly those that are not on a local file system,
+ // may have no clue of their block size. Better not to report something
+ // misleading or buggy and have a safe default.
return usermem.PageSize
}
@@ -149,6 +150,7 @@ func links(valid p9.AttrMask, pattr p9.Attr) uint64 {
}
// This node is likely backed by a file system that doesn't support links.
+ //
// We could readdir() and count children directories to provide an accurate
// link count. However this may be expensive since the gofer may be backed by remote
// storage. Instead, simply return 2 links for directories and 1 for everything else
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
index ebea03c42..07a564e92 100644
--- a/pkg/sentry/fs/gofer/cache_policy.go
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -127,6 +127,9 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child
childIops, ok := child.InodeOperations.(*inodeOperations)
if !ok {
+ if _, ok := child.InodeOperations.(*fifo); ok {
+ return false
+ }
panic(fmt.Sprintf("revalidating inode operations of unknown type %T", child.InodeOperations))
}
parentIops, ok := parent.InodeOperations.(*inodeOperations)
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
index 3da818aed..125907d70 100644
--- a/pkg/sentry/fs/gofer/context_file.go
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -73,6 +73,20 @@ func (c *contextFile) setXattr(ctx context.Context, name, value string, flags ui
return err
}
+func (c *contextFile) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := c.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
+func (c *contextFile) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := c.file.Allocate(mode, offset, length)
diff --git a/pkg/sentry/fs/gofer/fifo.go b/pkg/sentry/fs/gofer/fifo.go
new file mode 100644
index 000000000..456557058
--- /dev/null
+++ b/pkg/sentry/fs/gofer/fifo.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// +stateify savable
+type fifo struct {
+ fs.InodeOperations
+ fileIops *inodeOperations
+}
+
+var _ fs.InodeOperations = (*fifo)(nil)
+
+// Rename implements fs.InodeOperations. It forwards the call to the underlying
+// file inode to handle the file rename. Note that file key remains the same
+// after the rename to keep the endpoint mapping.
+func (i *fifo) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return i.fileIops.Rename(ctx, inode, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS implements fs.InodeOperations.
+func (i *fifo) StatFS(ctx context.Context) (fs.Info, error) {
+ return i.fileIops.StatFS(ctx)
+}
diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go
index 0c2f89ae8..2df2fe889 100644
--- a/pkg/sentry/fs/gofer/gofer_test.go
+++ b/pkg/sentry/fs/gofer/gofer_test.go
@@ -61,7 +61,7 @@ func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context
ctx := contexttest.Context(t)
sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{
file: rootFile,
- }, root.QID, p9.AttrMaskAll(), root.Attr, false /* socket */)
+ }, root.QID, p9.AttrMaskAll(), root.Attr)
m := fs.NewMountSource(ctx, s, &filesystem{}, fs.MountSourceFlags{})
rootInode := fs.NewInode(ctx, rootInodeOperations, m, sattr)
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index ac28174d2..1c934981b 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -604,18 +604,23 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
}
// GetXattr implements fs.InodeOperations.GetXattr.
-func (i *inodeOperations) GetXattr(ctx context.Context, inode *fs.Inode, name string, size uint64) (string, error) {
+func (i *inodeOperations) GetXattr(ctx context.Context, _ *fs.Inode, name string, size uint64) (string, error) {
return i.fileState.file.getXattr(ctx, name, size)
}
// SetXattr implements fs.InodeOperations.SetXattr.
-func (i *inodeOperations) SetXattr(ctx context.Context, inode *fs.Inode, name string, value string, flags uint32) error {
+func (i *inodeOperations) SetXattr(ctx context.Context, _ *fs.Inode, name string, value string, flags uint32) error {
return i.fileState.file.setXattr(ctx, name, value, flags)
}
// ListXattr implements fs.InodeOperations.ListXattr.
-func (i *inodeOperations) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
- return nil, syscall.EOPNOTSUPP
+func (i *inodeOperations) ListXattr(ctx context.Context, _ *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return i.fileState.file.listXattr(ctx, size)
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *inodeOperations) RemoveXattr(ctx context.Context, _ *fs.Inode, name string) error {
+ return i.fileState.file.removeXattr(ctx, name)
}
// Allocate implements fs.InodeOperations.Allocate.
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 0c1be05ef..a35c3a23d 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -23,14 +23,24 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
// encoding of strings, which uses 2 bytes for the length prefix.
const maxFilenameLen = (1 << 16) - 1
+func changeType(mode p9.FileMode, newType p9.FileMode) p9.FileMode {
+ if newType&^p9.FileModeMask != 0 {
+ panic(fmt.Sprintf("newType contained more bits than just file mode: %x", newType))
+ }
+ clear := mode &^ p9.FileModeMask
+ return clear | newType
+}
+
// Lookup loads an Inode at name into a Dirent based on the session's cache
// policy.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
@@ -69,8 +79,25 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
return nil, err
}
+ if i.session().overrides != nil {
+ // Check if file belongs to a internal named pipe. Note that it doesn't need
+ // to check for sockets because it's done in newInodeOperations below.
+ deviceKey := device.MultiDeviceKey{
+ Device: p9attr.RDev,
+ SecondaryDevice: i.session().connID,
+ Inode: qids[0].Path,
+ }
+ unlock := i.session().overrides.lock()
+ if pipeInode := i.session().overrides.getPipe(deviceKey); pipeInode != nil {
+ unlock()
+ pipeInode.IncRef()
+ return fs.NewDirent(ctx, pipeInode, name), nil
+ }
+ unlock()
+ }
+
// Construct the Inode operations.
- sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr, false)
+ sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr)
// Construct a positive Dirent.
return fs.NewDirent(ctx, fs.NewInode(ctx, node, dir.MountSource, sattr), name), nil
@@ -138,7 +165,7 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
qid := qids[0]
// Construct the InodeOperations.
- sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr, false)
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr)
// Construct the positive Dirent.
d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
@@ -223,82 +250,115 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
return nil, syserror.ENAMETOOLONG
}
- if i.session().endpoints == nil {
+ if i.session().overrides == nil {
return nil, syscall.EOPNOTSUPP
}
- // Create replaces the directory fid with the newly created/opened
- // file, so clone this directory so it doesn't change out from under
- // this node.
- _, newFile, err := i.fileState.file.walk(ctx, nil)
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
+
+ sattr, iops, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeSocket)
if err != nil {
return nil, err
}
- // We're not going to use newFile after return.
- defer newFile.close(ctx)
- // Stabilize the endpoint map while creation is in progress.
- unlock := i.session().endpoints.lock()
- defer unlock()
+ // Construct the positive Dirent.
+ childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
+ i.session().overrides.addBoundEndpoint(iops.fileState.key, childDir, ep)
+ return childDir, nil
+}
- // Create a regular file in the gofer and then mark it as a socket by
- // adding this inode key in the 'endpoints' map.
- owner := fs.FileOwnerFromContext(ctx)
- hostFile, err := newFile.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
- if err != nil {
- return nil, err
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
}
- // We're not going to use this file.
- hostFile.Close()
- i.touchModificationAndStatusChangeTime(ctx, dir)
+ owner := fs.FileOwnerFromContext(ctx)
+ mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
- // Get the attributes of the file to create inode key.
- qid, mask, attr, err := getattr(ctx, newFile)
- if err != nil {
- return nil, err
+ // N.B. FIFOs use major/minor numbers 0.
+ if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ if i.session().overrides == nil || err != syscall.EPERM {
+ return err
+ }
+ // If gofer doesn't support mknod, check if we can create an internal fifo.
+ return i.createInternalFifo(ctx, dir, name, owner, perm)
}
- key := device.MultiDeviceKey{
- Device: attr.RDev,
- SecondaryDevice: i.session().connID,
- Inode: qid.Path,
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, name string, owner fs.FileOwner, perm fs.FilePermissions) error {
+ if i.session().overrides == nil {
+ return syserror.EPERM
}
- // Create child dirent.
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
- // Get an unopened p9.File for the file we created so that it can be
- // cloned and re-opened multiple times after creation.
- _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ sattr, fileOps, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeNamedPipe)
if err != nil {
- return nil, err
+ return err
}
- // Construct the InodeOperations.
- sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr, true)
+ // First create a pipe.
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+
+ // Wrap the fileOps with our Fifo.
+ iops := &fifo{
+ InodeOperations: pipe.NewInodeOperations(ctx, perm, p),
+ fileIops: fileOps,
+ }
+ inode := fs.NewInode(ctx, iops, dir.MountSource, sattr)
// Construct the positive Dirent.
childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- i.session().endpoints.add(key, childDir, ep)
- return childDir, nil
+ i.session().overrides.addPipe(fileOps.fileState.key, childDir, inode)
+ return nil
}
-// CreateFifo implements fs.InodeOperations.CreateFifo.
-func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
- if len(name) > maxFilenameLen {
- return syserror.ENAMETOOLONG
+// Caller must hold Session.endpoint lock.
+func (i *inodeOperations) createEndpointFile(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions, fileType p9.FileMode) (fs.StableAttr, *inodeOperations, error) {
+ _, dirClone, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
}
+ // We're not going to use dirClone after return.
+ defer dirClone.close(ctx)
+ // Create a regular file in the gofer and then mark it as a socket by
+ // adding this inode key in the 'overrides' map.
owner := fs.FileOwnerFromContext(ctx)
- mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
-
- // N.B. FIFOs use major/minor numbers 0.
- if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
- return err
+ hostFile, err := dirClone.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ return fs.StableAttr{}, nil, err
}
+ // We're not going to use this file.
+ hostFile.Close()
i.touchModificationAndStatusChangeTime(ctx, dir)
- return nil
+
+ // Get the attributes of the file to create inode key.
+ qid, mask, attr, err := getattr(ctx, dirClone)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Get an unopened p9.File for the file we created so that it can be
+ // cloned and re-opened multiple times after creation.
+ _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Construct new inode with file type overridden.
+ attr.Mode = changeType(attr.Mode, fileType)
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr)
+ return sattr, iops, nil
}
// Remove implements InodeOperations.Remove.
@@ -307,20 +367,23 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
return syserror.ENAMETOOLONG
}
- var key device.MultiDeviceKey
- removeSocket := false
- if i.session().endpoints != nil {
- // Find out if file being deleted is a socket that needs to be
+ var key *device.MultiDeviceKey
+ if i.session().overrides != nil {
+ // Find out if file being deleted is a socket or pipe that needs to be
// removed from endpoint map.
if d, err := i.Lookup(ctx, dir, name); err == nil {
defer d.DecRef()
- if fs.IsSocket(d.Inode.StableAttr) {
- child := d.Inode.InodeOperations.(*inodeOperations)
- key = child.fileState.key
- removeSocket = true
- // Stabilize the endpoint map while deletion is in progress.
- unlock := i.session().endpoints.lock()
+ if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) {
+ switch iops := d.Inode.InodeOperations.(type) {
+ case *inodeOperations:
+ key = &iops.fileState.key
+ case *fifo:
+ key = &iops.fileIops.fileState.key
+ }
+
+ // Stabilize the override map while deletion is in progress.
+ unlock := i.session().overrides.lock()
defer unlock()
}
}
@@ -329,8 +392,8 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
if err := i.fileState.file.unlinkAt(ctx, name, 0); err != nil {
return err
}
- if removeSocket {
- i.session().endpoints.remove(key)
+ if key != nil {
+ i.session().overrides.remove(*key)
}
i.touchModificationAndStatusChangeTime(ctx, dir)
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 498c4645a..f6b3ef178 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -33,60 +33,107 @@ import (
var DefaultDirentCacheSize uint64 = fs.DefaultDirentCacheSize
// +stateify savable
-type endpointMaps struct {
- // mu protexts the direntMap, the keyMap, and the pathMap below.
- mu sync.RWMutex `state:"nosave"`
+type overrideInfo struct {
+ dirent *fs.Dirent
+
+ // endpoint is set when dirent points to a socket. inode must not be set.
+ endpoint transport.BoundEndpoint
+
+ // inode is set when dirent points to a pipe. endpoint must not be set.
+ inode *fs.Inode
+}
- // direntMap links sockets to their dirents.
- // It is filled concurrently with the keyMap and is stored upon save.
- // Before saving, this map is used to populate the pathMap.
- direntMap map[transport.BoundEndpoint]*fs.Dirent
+func (l *overrideInfo) inodeType() fs.InodeType {
+ switch {
+ case l.endpoint != nil:
+ return fs.Socket
+ case l.inode != nil:
+ return fs.Pipe
+ }
+ panic("endpoint or node must be set")
+}
- // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets.
+// +stateify savable
+type overrideMaps struct {
+ // mu protexts the keyMap, and the pathMap below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets/pipes.
// It is not stored during save because the inode ID may change upon restore.
- keyMap map[device.MultiDeviceKey]transport.BoundEndpoint `state:"nosave"`
+ keyMap map[device.MultiDeviceKey]*overrideInfo `state:"nosave"`
- // pathMap links the sockets to their paths.
+ // pathMap links the sockets/pipes to their paths.
// It is filled before saving from the direntMap and is stored upon save.
// Upon restore, this map is used to re-populate the keyMap.
- pathMap map[transport.BoundEndpoint]string
+ pathMap map[*overrideInfo]string
+}
+
+// addBoundEndpoint adds the bound endpoint to the map.
+// A reference is taken on the dirent argument.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *overrideMaps) addBoundEndpoint(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
+ d.IncRef()
+ e.keyMap[key] = &overrideInfo{dirent: d, endpoint: ep}
}
-// add adds the endpoint to the maps.
+// addPipe adds the pipe inode to the map.
// A reference is taken on the dirent argument.
//
// Precondition: maps must have been locked with 'lock'.
-func (e *endpointMaps) add(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
- e.keyMap[key] = ep
+func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *fs.Inode) {
d.IncRef()
- e.direntMap[ep] = d
+ e.keyMap[key] = &overrideInfo{dirent: d, inode: inode}
}
// remove deletes the key from the maps.
//
// Precondition: maps must have been locked with 'lock'.
-func (e *endpointMaps) remove(key device.MultiDeviceKey) {
- endpoint := e.get(key)
+func (e *overrideMaps) remove(key device.MultiDeviceKey) {
+ endpoint := e.keyMap[key]
delete(e.keyMap, key)
-
- d := e.direntMap[endpoint]
- d.DecRef()
- delete(e.direntMap, endpoint)
+ endpoint.dirent.DecRef()
}
// lock blocks other addition and removal operations from happening while
// the backing file is being created or deleted. Returns a function that unlocks
// the endpoint map.
-func (e *endpointMaps) lock() func() {
+func (e *overrideMaps) lock() func() {
e.mu.Lock()
return func() { e.mu.Unlock() }
}
-// get returns the endpoint mapped to the given key.
+// getBoundEndpoint returns the bound endpoint mapped to the given key.
//
-// Precondition: maps must have been locked for reading.
-func (e *endpointMaps) get(key device.MultiDeviceKey) transport.BoundEndpoint {
- return e.keyMap[key]
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getBoundEndpoint(key device.MultiDeviceKey) transport.BoundEndpoint {
+ if v := e.keyMap[key]; v != nil {
+ return v.endpoint
+ }
+ return nil
+}
+
+// getPipe returns the pipe inode mapped to the given key.
+//
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getPipe(key device.MultiDeviceKey) *fs.Inode {
+ if v := e.keyMap[key]; v != nil {
+ return v.inode
+ }
+ return nil
+}
+
+// getType returns the inode type if there is a corresponding endpoint for the
+// given key. Returns false otherwise.
+func (e *overrideMaps) getType(key device.MultiDeviceKey) (fs.InodeType, bool) {
+ e.mu.Lock()
+ v := e.keyMap[key]
+ e.mu.Unlock()
+
+ if v != nil {
+ return v.inodeType(), true
+ }
+ return 0, false
}
// session holds state for each 9p session established during sys_mount.
@@ -137,16 +184,16 @@ type session struct {
// mounter is the EUID/EGID that mounted this file system.
mounter fs.FileOwner `state:"wait"`
- // endpoints is used to map inodes that represent socket files to their
- // corresponding endpoint. Socket files are created as regular files in the
- // gofer and their presence in this map indicate that they should indeed be
- // socket files. This allows unix domain sockets to be used with paths that
- // belong to a gofer.
+ // overrides is used to map inodes that represent socket/pipes files to their
+ // corresponding endpoint/iops. These files are created as regular files in
+ // the gofer and their presence in this map indicate that they should indeed
+ // be socket/pipe files. This allows unix domain sockets and named pipes to
+ // be used with paths that belong to a gofer.
//
// TODO(gvisor.dev/issue/1200): there are few possible races with someone
// stat'ing the file and another deleting it concurrently, where the file
// will not be reported as socket file.
- endpoints *endpointMaps `state:"wait"`
+ overrides *overrideMaps `state:"wait"`
}
// Destroy tears down the session.
@@ -179,15 +226,21 @@ func (s *session) SaveInodeMapping(inode *fs.Inode, path string) {
// This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
// because overlay copyUp may have changed them out from under us.
// So much for "immutable".
- sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
- s.inodeMappings[sattr.InodeID] = path
+ switch iops := inode.InodeOperations.(type) {
+ case *inodeOperations:
+ s.inodeMappings[iops.fileState.sattr.InodeID] = path
+ case *fifo:
+ s.inodeMappings[iops.fileIops.fileState.sattr.InodeID] = path
+ default:
+ panic(fmt.Sprintf("Invalid type: %T", iops))
+ }
}
-// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File and attributes
-// (p9.QID, p9.AttrMask, p9.Attr).
+// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File
+// and attributes (p9.QID, p9.AttrMask, p9.Attr).
//
// Endpoints lock must not be held if socket == false.
-func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr, socket bool) (fs.StableAttr, *inodeOperations) {
+func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr) (fs.StableAttr, *inodeOperations) {
deviceKey := device.MultiDeviceKey{
Device: attr.RDev,
SecondaryDevice: s.connID,
@@ -201,17 +254,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
BlockSize: bsize(attr),
}
- if s.endpoints != nil {
- if socket {
- sattr.Type = fs.Socket
- } else {
- // If unix sockets are allowed on this filesystem, check if this file is
- // supposed to be a socket file.
- unlock := s.endpoints.lock()
- if s.endpoints.get(deviceKey) != nil {
- sattr.Type = fs.Socket
- }
- unlock()
+ if s.overrides != nil && sattr.Type == fs.RegularFile {
+ // If overrides are allowed on this filesystem, check if this file is
+ // supposed to be of a different type, e.g. socket.
+ if t, ok := s.overrides.getType(deviceKey); ok {
+ sattr.Type = t
}
}
@@ -267,7 +314,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
s.EnableLeakCheck("gofer.session")
if o.privateunixsocket {
- s.endpoints = newEndpointMaps()
+ s.overrides = newOverrideMaps()
}
// Construct the MountSource with the session and superBlockFlags.
@@ -305,26 +352,24 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
return nil, err
}
- sattr, iops := newInodeOperations(ctx, &s, s.attach, qid, valid, attr, false)
+ sattr, iops := newInodeOperations(ctx, &s, s.attach, qid, valid, attr)
return fs.NewInode(ctx, iops, m, sattr), nil
}
-// newEndpointMaps creates a new endpointMaps.
-func newEndpointMaps() *endpointMaps {
- return &endpointMaps{
- direntMap: make(map[transport.BoundEndpoint]*fs.Dirent),
- keyMap: make(map[device.MultiDeviceKey]transport.BoundEndpoint),
- pathMap: make(map[transport.BoundEndpoint]string),
+// newOverrideMaps creates a new overrideMaps.
+func newOverrideMaps() *overrideMaps {
+ return &overrideMaps{
+ keyMap: make(map[device.MultiDeviceKey]*overrideInfo),
+ pathMap: make(map[*overrideInfo]string),
}
}
-// fillKeyMap populates key and dirent maps upon restore from saved
-// pathmap.
+// fillKeyMap populates key and dirent maps upon restore from saved pathmap.
func (s *session) fillKeyMap(ctx context.Context) error {
- unlock := s.endpoints.lock()
+ unlock := s.overrides.lock()
defer unlock()
- for ep, dirPath := range s.endpoints.pathMap {
+ for ep, dirPath := range s.overrides.pathMap {
_, file, err := s.attach.walk(ctx, splitAbsolutePath(dirPath))
if err != nil {
return fmt.Errorf("error filling endpointmaps, failed to walk to %q: %v", dirPath, err)
@@ -341,25 +386,25 @@ func (s *session) fillKeyMap(ctx context.Context) error {
Inode: qid.Path,
}
- s.endpoints.keyMap[key] = ep
+ s.overrides.keyMap[key] = ep
}
return nil
}
-// fillPathMap populates paths for endpoints from dirents in direntMap
+// fillPathMap populates paths for overrides from dirents in direntMap
// before save.
func (s *session) fillPathMap() error {
- unlock := s.endpoints.lock()
+ unlock := s.overrides.lock()
defer unlock()
- for ep, dir := range s.endpoints.direntMap {
- mountRoot := dir.MountRoot()
+ for _, endpoint := range s.overrides.keyMap {
+ mountRoot := endpoint.dirent.MountRoot()
defer mountRoot.DecRef()
- dirPath, _ := dir.FullName(mountRoot)
+ dirPath, _ := endpoint.dirent.FullName(mountRoot)
if dirPath == "" {
return fmt.Errorf("error getting path from dirent")
}
- s.endpoints.pathMap[ep] = dirPath
+ s.overrides.pathMap[endpoint] = dirPath
}
return nil
}
@@ -368,7 +413,7 @@ func (s *session) fillPathMap() error {
func (s *session) restoreEndpointMaps(ctx context.Context) error {
// When restoring, only need to create the keyMap because the dirent and path
// maps got stored through the save.
- s.endpoints.keyMap = make(map[device.MultiDeviceKey]transport.BoundEndpoint)
+ s.overrides.keyMap = make(map[device.MultiDeviceKey]*overrideInfo)
if err := s.fillKeyMap(ctx); err != nil {
return fmt.Errorf("failed to insert sockets into endpoint map: %v", err)
}
@@ -376,6 +421,6 @@ func (s *session) restoreEndpointMaps(ctx context.Context) error {
// Re-create pathMap because it can no longer be trusted as socket paths can
// change while process continues to run. Empty pathMap will be re-filled upon
// next save.
- s.endpoints.pathMap = make(map[transport.BoundEndpoint]string)
+ s.overrides.pathMap = make(map[*overrideInfo]string)
return nil
}
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
index 0285c5361..111da59f9 100644
--- a/pkg/sentry/fs/gofer/session_state.go
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -25,9 +25,9 @@ import (
// beforeSave is invoked by stateify.
func (s *session) beforeSave() {
- if s.endpoints != nil {
+ if s.overrides != nil {
if err := s.fillPathMap(); err != nil {
- panic("failed to save paths to endpoint map before saving" + err.Error())
+ panic("failed to save paths to override map before saving" + err.Error())
}
}
}
@@ -74,10 +74,10 @@ func (s *session) afterLoad() {
panic(fmt.Sprintf("new attach name %v, want %v", opts.aname, s.aname))
}
- // Check if endpointMaps exist when uds sockets are enabled
- // (only pathmap will actualy have been saved).
- if opts.privateunixsocket != (s.endpoints != nil) {
- panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.endpoints != nil))
+ // Check if overrideMaps exist when uds sockets are enabled (only pathmaps
+ // will actually have been saved).
+ if opts.privateunixsocket != (s.overrides != nil) {
+ panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.overrides != nil))
}
if args.Flags != s.superBlockFlags {
panic(fmt.Sprintf("new mount flags %v, want %v", args.Flags, s.superBlockFlags))
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index 376cfce2c..10ba2f5f0 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -32,15 +32,15 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
return nil
}
- if i.session().endpoints != nil {
- unlock := i.session().endpoints.lock()
+ if i.session().overrides != nil {
+ unlock := i.session().overrides.lock()
defer unlock()
- ep := i.session().endpoints.get(i.fileState.key)
+ ep := i.session().overrides.getBoundEndpoint(i.fileState.key)
if ep != nil {
return ep
}
- // Not found in endpoints map, it may be a gofer backed unix socket...
+ // Not found in overrides map, it may be a gofer backed unix socket...
}
inode.IncRef()
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index b66c091ab..55fb71c16 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -278,11 +278,19 @@ func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, fla
}
// ListXattr calls i.InodeOperations.ListXattr with i as the Inode.
-func (i *Inode) ListXattr(ctx context.Context) (map[string]struct{}, error) {
+func (i *Inode) ListXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
if i.overlay != nil {
- return overlayListXattr(ctx, i.overlay)
+ return overlayListXattr(ctx, i.overlay, size)
}
- return i.InodeOperations.ListXattr(ctx, i)
+ return i.InodeOperations.ListXattr(ctx, i, size)
+}
+
+// RemoveXattr calls i.InodeOperations.RemoveXattr with i as the Inode.
+func (i *Inode) RemoveXattr(ctx context.Context, d *Dirent, name string) error {
+ if i.overlay != nil {
+ return overlayRemoveXattr(ctx, i.overlay, d, name)
+ }
+ return i.InodeOperations.RemoveXattr(ctx, i, name)
}
// CheckPermission will check if the caller may access this file in the
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
index 70f2eae96..2bbfb72ef 100644
--- a/pkg/sentry/fs/inode_operations.go
+++ b/pkg/sentry/fs/inode_operations.go
@@ -190,7 +190,18 @@ type InodeOperations interface {
// ListXattr returns the set of all extended attributes names that
// have values. Inodes that do not support extended attributes return
// EOPNOTSUPP.
- ListXattr(ctx context.Context, inode *Inode) (map[string]struct{}, error)
+ //
+ // If this is called through the listxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely. All size checking is done independently
+ // at the syscall layer.
+ ListXattr(ctx context.Context, inode *Inode, size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes an extended attribute specified by name. Inodes that
+ // do not support extended attributes return EOPNOTSUPP.
+ RemoveXattr(ctx context.Context, inode *Inode, name string) error
// Check determines whether an Inode can be accessed with the
// requested permission mask using the context (which gives access
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 4729b4aac..5ada33a32 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -564,15 +564,15 @@ func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, valu
return o.upper.SetXattr(ctx, d, name, value, flags)
}
-func overlayListXattr(ctx context.Context, o *overlayEntry) (map[string]struct{}, error) {
+func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[string]struct{}, error) {
o.copyMu.RLock()
defer o.copyMu.RUnlock()
var names map[string]struct{}
var err error
if o.upper != nil {
- names, err = o.upper.ListXattr(ctx)
+ names, err = o.upper.ListXattr(ctx, size)
} else {
- names, err = o.lower.ListXattr(ctx)
+ names, err = o.lower.ListXattr(ctx, size)
}
for name := range names {
// Same as overlayGetXattr, we shouldn't forward along
@@ -584,6 +584,18 @@ func overlayListXattr(ctx context.Context, o *overlayEntry) (map[string]struct{}
return names, err
}
+func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error {
+ // Don't allow changes to overlay xattrs through a removexattr syscall.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.RemoveXattr(ctx, d, name)
+}
+
func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error {
o.copyMu.RLock()
// Hot path. Avoid defers.
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index e672a438c..a3d10770b 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 574a2cc91..c7981f66e 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 280093c5e..77c2c5c0e 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -36,6 +36,7 @@ go_library(
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/fs/proc/seqfile",
"//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md
index 5d4ec6c7b..6667a0916 100644
--- a/pkg/sentry/fs/proc/README.md
+++ b/pkg/sentry/fs/proc/README.md
@@ -11,6 +11,8 @@ inconsistency, please file a bug.
The following files are implemented:
+<!-- mdformat off(don't wrap the table) -->
+
| File /proc/ | Content |
| :------------------------ | :---------------------------------------------------- |
| [cpuinfo](#cpuinfo) | Info about the CPU |
@@ -22,6 +24,8 @@ The following files are implemented:
| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus |
| [version](#version) | Kernel version |
+<!-- mdformat on -->
+
### cpuinfo
```bash
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index c10888100..94deb553b 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
mroot.DecRef()
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
defer root.DecRef()
flags := root.Inode.MountSource.Flags
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6f2775344..95d5817ff 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -43,7 +43,10 @@ import (
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
"snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 0772d4ae4..d4c4b533d 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index ca020e11e..8ab8d8a02 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -249,7 +250,7 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, exeSymlink, msrc, fs.Symlink, t)
}
-func (e *exe) executable() (d *fs.Dirent, err error) {
+func (e *exe) executable() (file fsbridge.File, err error) {
e.t.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
@@ -262,8 +263,8 @@ func (e *exe) executable() (d *fs.Dirent, err error) {
// The MemoryManager may be destroyed, in which case
// MemoryManager.destroy will simply set the executable to nil
// (with locks held).
- d = mm.Executable()
- if d == nil {
+ file = mm.Executable()
+ if file == nil {
err = syserror.ENOENT
}
})
@@ -283,15 +284,7 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
}
defer exec.DecRef()
- root := fs.RootFromContext(ctx)
- if root == nil {
- // This doesn't correspond to anything in Linux because the vfs is
- // global there.
- return "", syserror.EINVAL
- }
- defer root.DecRef()
- n, _ := exec.FullName(root)
- return n, nil
+ return exec.PathnameWithDeleted(ctx), nil
}
// namespaceSymlink represents a symlink in the namespacefs, such as the files
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index dabc10662..25abbc151 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -17,6 +17,7 @@ package tmpfs
import (
"fmt"
"io"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -444,10 +445,15 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
defer rw.f.dataMu.Unlock()
// Compute the range to write.
- end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
- if end == rw.offset { // srcs.NumBytes() == 0?
+ if srcs.NumBytes() == 0 {
+ // Nothing to do.
return 0, nil
}
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == math.MaxInt64 {
+ // Overflow.
+ return 0, syserror.EINVAL
+ }
// Check if seals prevent either file growth or all writes.
switch {
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index c00cef0a5..3c2b583ae 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -159,8 +159,13 @@ func (d *Dir) SetXattr(ctx context.Context, i *fs.Inode, name, value string, fla
}
// ListXattr implements fs.InodeOperations.ListXattr.
-func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode) (map[string]struct{}, error) {
- return d.ramfsDir.ListXattr(ctx, i)
+func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return d.ramfsDir.ListXattr(ctx, i, size)
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (d *Dir) RemoveXattr(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.RemoveXattr(ctx, i, name)
}
// Lookup implements fs.InodeOperations.Lookup.
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index db55cdc48..6a2dbc576 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -73,7 +73,7 @@ func (si *slaveInodeOperations) Release(ctx context.Context) {
}
// Truncate implements fs.InodeOperations.Truncate.
-func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD
new file mode 100644
index 000000000..6c798f0bd
--- /dev/null
+++ b/pkg/sentry/fsbridge/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsbridge",
+ srcs = [
+ "bridge.go",
+ "fs.go",
+ "vfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go
new file mode 100644
index 000000000..8e7590721
--- /dev/null
+++ b/pkg/sentry/fsbridge/bridge.go
@@ -0,0 +1,54 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsbridge provides common interfaces to bridge between VFS1 and VFS2
+// files.
+package fsbridge
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// File provides a common interface to bridge between VFS1 and VFS2 files.
+type File interface {
+ // PathnameWithDeleted returns an absolute pathname to vd, consistent with
+ // Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+ // PathnameWithDeleted appends " (deleted)" to the returned pathname.
+ PathnameWithDeleted(ctx context.Context) string
+
+ // ReadFull read all contents from the file.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file.
+ ConfigureMMap(context.Context, *memmap.MMapOpts) error
+
+ // Type returns the file type, e.g. linux.S_IFREG.
+ Type(context.Context) (linux.FileMode, error)
+
+ // IncRef increments reference.
+ IncRef()
+
+ // DecRef decrements reference.
+ DecRef()
+}
+
+// Lookup provides a common interface to open files.
+type Lookup interface {
+ // OpenPath opens a file.
+ OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error)
+}
diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go
new file mode 100644
index 000000000..093ce1fb3
--- /dev/null
+++ b/pkg/sentry/fsbridge/fs.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements File interface over fs.File.
+//
+// +stateify savable
+type fsFile struct {
+ file *fs.File
+}
+
+var _ File = (*fsFile)(nil)
+
+// NewFSFile creates a new File over fs.File.
+func NewFSFile(file *fs.File) File {
+ return &fsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *fsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // This doesn't correspond to anything in Linux because the vfs is
+ // global there.
+ return ""
+ }
+ defer root.DecRef()
+
+ name, _ := f.file.Dirent.FullName(root)
+ return name
+}
+
+// ReadFull implements File.
+func (f *fsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.Preadv(ctx, dst, offset+total)
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *fsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *fsFile) Type(context.Context) (linux.FileMode, error) {
+ return linux.FileMode(f.file.Dirent.Inode.StableAttr.Type.LinuxType()), nil
+}
+
+// IncRef implements File.
+func (f *fsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *fsFile) DecRef() {
+ f.file.DecRef()
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type fsLookup struct {
+ mntns *fs.MountNamespace
+
+ root *fs.Dirent
+ workingDir *fs.Dirent
+}
+
+var _ Lookup = (*fsLookup)(nil)
+
+// NewFSLookup creates a new Lookup using VFS1.
+func NewFSLookup(mntns *fs.MountNamespace, root, workingDir *fs.Dirent) Lookup {
+ return &fsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error) {
+ var d *fs.Dirent
+ var err error
+ if resolveFinal {
+ d, err = l.mntns.FindInode(ctx, l.root, l.workingDir, path, remainingTraversals)
+ } else {
+ d, err = l.mntns.FindLink(ctx, l.root, l.workingDir, path, remainingTraversals)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer d.DecRef()
+
+ if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
+ return nil, syserror.ELOOP
+ }
+
+ fsPerm := openOptionsToPermMask(&opts)
+ if err := d.Inode.CheckPermission(ctx, fsPerm); err != nil {
+ return nil, err
+ }
+
+ // If they claim it's a directory, then make sure.
+ if strings.HasSuffix(path, "/") {
+ if d.Inode.StableAttr.Type != fs.Directory {
+ return nil, syserror.ENOTDIR
+ }
+ }
+
+ if opts.FileExec && d.Inode.StableAttr.Type != fs.RegularFile {
+ ctx.Infof("%q is not a regular file: %v", path, d.Inode.StableAttr.Type)
+ return nil, syserror.EACCES
+ }
+
+ f, err := d.Inode.GetFile(ctx, d, flagsToFileFlags(opts.Flags))
+ if err != nil {
+ return nil, err
+ }
+
+ return &fsFile{file: f}, nil
+}
+
+func openOptionsToPermMask(opts *vfs.OpenOptions) fs.PermMask {
+ mode := opts.Flags & linux.O_ACCMODE
+ return fs.PermMask{
+ Read: mode == linux.O_RDONLY || mode == linux.O_RDWR,
+ Write: mode == linux.O_WRONLY || mode == linux.O_RDWR,
+ Execute: opts.FileExec,
+ }
+}
+
+func flagsToFileFlags(flags uint32) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: flags&linux.O_DIRECT != 0,
+ DSync: flags&(linux.O_DSYNC|linux.O_SYNC) != 0,
+ Sync: flags&linux.O_SYNC != 0,
+ NonBlocking: flags&linux.O_NONBLOCK != 0,
+ Read: (flags & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (flags & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: flags&linux.O_APPEND != 0,
+ Directory: flags&linux.O_DIRECTORY != 0,
+ Async: flags&linux.O_ASYNC != 0,
+ LargeFile: flags&linux.O_LARGEFILE != 0,
+ Truncate: flags&linux.O_TRUNC != 0,
+ }
+}
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
new file mode 100644
index 000000000..6aa17bfc1
--- /dev/null
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -0,0 +1,138 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements File interface over vfs.FileDescription.
+//
+// +stateify savable
+type vfsFile struct {
+ file *vfs.FileDescription
+}
+
+var _ File = (*vfsFile)(nil)
+
+// NewVFSFile creates a new File over fs.File.
+func NewVFSFile(file *vfs.FileDescription) File {
+ return &vfsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *vfsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef()
+
+ vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry())
+ return name
+}
+
+// ReadFull implements File.
+func (f *vfsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *vfsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *vfsFile) Type(ctx context.Context) (linux.FileMode, error) {
+ stat, err := f.file.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
+ return linux.FileMode(stat.Mode).FileType(), nil
+}
+
+// IncRef implements File.
+func (f *vfsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *vfsFile) DecRef() {
+ f.file.DecRef()
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type vfsLookup struct {
+ mntns *vfs.MountNamespace
+
+ root vfs.VirtualDentry
+ workingDir vfs.VirtualDentry
+}
+
+var _ Lookup = (*vfsLookup)(nil)
+
+// NewVFSLookup creates a new Lookup using VFS2.
+func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) Lookup {
+ return &vfsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+//
+// remainingTraversals is not configurable in VFS2, all callers are using the
+// default anyways.
+//
+// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+ vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
+ pop := &vfs.PathOperation{
+ Root: l.root,
+ Start: l.workingDir,
+ Path: path,
+ FollowFinalSymlink: resolveFinal,
+ }
+ if path.Absolute {
+ pop.Start = l.root
+ }
+ fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return &vfsFile{file: fd}, nil
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
index d36fa74fb..abd4f24e7 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -28,6 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/sync"
)
+// Name is the default filesystem name.
+const Name = "devtmpfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct {
initOnce sync.Once
@@ -86,7 +89,7 @@ func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth
// Release must be called when a is no longer in use.
func (a *Accessor) Release() {
a.root.DecRef()
- a.mntns.DecRef(a.vfsObj)
+ a.mntns.DecRef()
}
// accessorContext implements context.Context by extending an existing
@@ -107,6 +110,7 @@ func (a *Accessor) wrapContext(ctx context.Context) *accessorContext {
func (ac *accessorContext) Value(key interface{}) interface{} {
switch key {
case vfs.CtxMountNamespace:
+ ac.a.mntns.IncRef()
return ac.a.mntns
case vfs.CtxRoot:
ac.a.root.IncRef()
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
index 82c58c900..b6d52c015 100644
--- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -29,7 +29,10 @@ func TestDevtmpfs(t *testing.T) {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
// Register tmpfs just so that we can have a root filesystem that isn't
// devtmpfs.
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
@@ -45,7 +48,7 @@ func TestDevtmpfs(t *testing.T) {
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef(vfsObj)
+ defer mntns.DecRef()
root := mntns.Root()
defer root.DecRef()
devpop := vfs.PathOperation{
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index d1436b943..89caee3df 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -15,6 +15,9 @@
// These benchmarks emulate memfs benchmarks. Ext4 images must be created
// before this benchmark is run using the `make_deep_ext4.sh` script at
// /tmp/image-{depth}.ext4 for all the depths tested below.
+//
+// The benchmark itself cannot run the script because the script requires
+// sudo privileges to create the file system images.
package benchmark_test
import (
@@ -49,7 +52,10 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, nil, nil, err
+ }
vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index ebb72b75e..bd6ede995 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -188,14 +188,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
}
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: child.diskDirent.FileName(),
Type: fs.ToDirentType(childType),
Ino: uint64(child.diskDirent.Inode()),
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 05f992826..29bb73765 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -65,7 +65,10 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -496,9 +499,9 @@ func newIterDirentCb() *iterDirentsCb {
}
// Handle implements vfs.IterDirentsCallback.Handle.
-func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) bool {
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) error {
cb.dirents = append(cb.dirents, dirent)
- return true
+ return nil
}
// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality.
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 07bf58953..e05429d41 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -296,7 +296,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 {
return nil, syserror.EROFS
}
- return inode.open(rp, vfsd, opts.Flags)
+ return inode.open(rp, vfsd, &opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 191b39970..6962083f5 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -148,8 +148,8 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
// open creates and returns a file description for the dentry passed in.
-func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
+func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
if err := in.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
@@ -157,7 +157,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- if err := fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
@@ -168,17 +168,17 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- if err := fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
return nil, err
}
return &fd.vfsfd, nil
case *symlink:
- if flags&linux.O_PATH == 0 {
+ if opts.Flags&linux.O_PATH == 0 {
// Can't open symlinks without O_PATH.
return nil, syserror.ELOOP
}
var fd symlinkFD
- fd.vfsfd.Init(&fd, flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+ fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
new file mode 100644
index 000000000..4ba76a1e8
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "gofer",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
+go_library(
+ name = "gofer",
+ srcs = [
+ "dentry_list.go",
+ "directory.go",
+ "filesystem.go",
+ "gofer.go",
+ "handle.go",
+ "handle_unsafe.go",
+ "p9file.go",
+ "pagemath.go",
+ "regular_file.go",
+ "special_file.go",
+ "symlink.go",
+ "time.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/unet",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
new file mode 100644
index 000000000..5dbfc6250
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -0,0 +1,194 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func (d *dentry) isDir() bool {
+ return d.fileType() == linux.S_IFDIR
+}
+
+// Preconditions: d.dirMu must be locked. d.isDir(). fs.opts.interop !=
+// InteropModeShared.
+func (d *dentry) cacheNegativeChildLocked(name string) {
+ if d.negativeChildren == nil {
+ d.negativeChildren = make(map[string]struct{})
+ }
+ d.negativeChildren[name] = struct{}{}
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release() {
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ if fd.dirents == nil {
+ ds, err := fd.dentry().getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir(). There exists at least one directoryFD representing d.
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ // 9P2000.L's readdir does not specify behavior in the presence of
+ // concurrent mutation of an iterated directory, so implementations may
+ // duplicate or omit entries in this case, which violates POSIX semantics.
+ // Thus we read all directory entries while holding d.dirMu to exclude
+ // directory mutations. (Note that it is impossible for the client to
+ // exclude concurrent mutation from other remote filesystem users. Since
+ // there is no way to detect if the server has incorrectly omitted
+ // directory entries, we simply assume that the server is well-behaved
+ // under InteropModeShared.) This is inconsistent with Linux (which appears
+ // to assume that directory fids have the correct semantics, and translates
+ // struct file_operations::readdir calls directly to readdir RPCs), but is
+ // consistent with VFS1.
+ //
+ // NOTE(b/135560623): In particular, some gofer implementations may not
+ // retain state between calls to Readdir, so may not provide a coherent
+ // directory stream across in the presence of mutation.
+
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ // It's not clear if 9P2000.L's readdir is expected to return "." and "..",
+ // so we generate them here.
+ parent := d.vfsd.ParentOrSelf().Impl().(*dentry)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: d.ino,
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: parent.ino,
+ NextOff: 2,
+ },
+ }
+ off := uint64(0)
+ const count = 64 * 1024 // for consistency with the vfs1 client
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ if !d.handleReadable {
+ // This should not be possible because a readable handle should have
+ // been opened when the calling directoryFD was opened.
+ panic("gofer.dentry.getDirents called without a readable handle")
+ }
+ for {
+ p9ds, err := d.handle.file.readdir(ctx, off, count)
+ if err != nil {
+ return nil, err
+ }
+ if len(p9ds) == 0 {
+ // Cache dirents for future directoryFDs if permitted.
+ if d.fs.opts.interop != InteropModeShared {
+ d.dirents = dirents
+ }
+ return dirents, nil
+ }
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: p9d.QID.Path,
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ }
+ off = p9ds[len(p9ds)-1].Offset
+ }
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
new file mode 100644
index 000000000..5cfb0dc4c
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -0,0 +1,1090 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // Snapshot current dentries and special files.
+ fs.syncMu.Lock()
+ ds := make([]*dentry, 0, len(fs.dentries))
+ for d := range fs.dentries {
+ ds = append(ds, d)
+ }
+ sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs))
+ for sffd := range fs.specialFileFDs {
+ sffds = append(sffds, sffd)
+ }
+ fs.syncMu.Unlock()
+
+ // Return the first error we encounter, but sync everything we can
+ // regardless.
+ var retErr error
+
+ // Sync regular files.
+ for _, d := range ds {
+ if !d.TryIncRef() {
+ continue
+ }
+ err := d.syncSharedHandle(ctx)
+ d.DecRef()
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ if !sffd.vfsfd.TryIncRef() {
+ continue
+ }
+ err := sffd.Sync(ctx)
+ sffd.vfsfd.DecRef()
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+}
+
+// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
+// encoding of strings, which uses 2 bytes for the length prefix.
+const maxFilenameLen = (1 << 16) - 1
+
+// dentrySlicePool is a pool of *[]*dentry used to store dentries for which
+// dentry.checkCachingLocked() must be called. The pool holds pointers to
+// slices because Go lacks generics, so sync.Pool operates on interface{}, so
+// every call to (what should be) sync.Pool<[]*dentry>.Put() allocates a copy
+// of the slice header on the heap.
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may become cached as a result of the traversal are appended
+// to *ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done(). If fs.opts.interop == InteropModeShared, then d's cached
+// metadata must be up to date.
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ parentVFSD, err := rp.ResolveParent(&d.vfsd)
+ if err != nil {
+ return nil, err
+ }
+ parent := parentVFSD.Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // We must assume that parentVFSD is correct, because if d has been
+ // moved elsewhere in the remote filesystem so that its parent has
+ // changed, we have no way of determining its new parent's location
+ // in the filesystem. Get updated metadata for parentVFSD.
+ _, attrMask, attr, err := parent.file.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ return nil, err
+ }
+ parent.updateFromP9Attrs(attrMask, &attr)
+ }
+ rp.Advance()
+ return parent, nil
+ }
+ childVFSD, err := rp.ResolveChild(&d.vfsd, name)
+ if err != nil {
+ return nil, err
+ }
+ // FIXME(jamieliu): Linux performs revalidation before mount lookup
+ // (fs/namei.c:lookup_fast() => __d_lookup_rcu(), d_revalidate(),
+ // __follow_mount_rcu()).
+ child, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, childVFSD, ds)
+ if err != nil {
+ return nil, err
+ }
+ if child == nil {
+ return nil, syserror.ENOENT
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// revalidateChildLocked must be called after a call to parent.vfsd.Child(name)
+// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be
+// nil) to verify that the returned child (or lack thereof) is correct. If no file
+// exists at name, revalidateChildLocked returns (nil, nil).
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// parent.isDir(). name is not "." or "..".
+//
+// Postconditions: If revalidateChildLocked returns a non-nil dentry, its
+// cached metadata is up to date.
+func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, childVFSD *vfs.Dentry, ds **[]*dentry) (*dentry, error) {
+ if childVFSD != nil && fs.opts.interop != InteropModeShared {
+ // We have a cached dentry that is assumed to be correct.
+ return childVFSD.Impl().(*dentry), nil
+ }
+ // We either don't have a cached dentry or need to verify that it's still
+ // correct, either of which requires a remote lookup. Check if this name is
+ // valid before performing the lookup.
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+ // Check if we've already cached this lookup with a negative result.
+ if _, ok := parent.negativeChildren[name]; ok {
+ return nil, nil
+ }
+ // Perform the remote lookup.
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil && err != syserror.ENOENT {
+ return nil, err
+ }
+ if childVFSD != nil {
+ child := childVFSD.Impl().(*dentry)
+ if !file.isNil() && qid.Path == child.ino {
+ // The file at this path hasn't changed. Just update cached
+ // metadata.
+ file.close(ctx)
+ child.updateFromP9Attrs(attrMask, &attr)
+ return child, nil
+ }
+ // The file at this path has changed or no longer exists. Remove
+ // the stale dentry from the tree, and re-evaluate its caching
+ // status (i.e. if it has 0 references, drop it).
+ vfsObj.ForceDeleteDentry(childVFSD)
+ *ds = appendDentry(*ds, child)
+ childVFSD = nil
+ }
+ if file.isNil() {
+ // No file exists at this path now. Cache the negative lookup if
+ // allowed.
+ if fs.opts.interop != InteropModeShared {
+ parent.cacheNegativeChildLocked(name)
+ }
+ return nil, nil
+ }
+ // Create a new dentry representing the file.
+ child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ return nil, err
+ }
+ parent.IncRef() // reference held by child on its parent
+ parent.vfsd.InsertChild(&child.vfsd, name)
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked().
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done(). If fs.opts.interop ==
+// InteropModeShared, then d's cached metadata must be up to date.
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // Get updated metadata for rp.Start() as required by fs.stepLocked().
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ return err
+ }
+ if parent.isDeleted() {
+ return syserror.ENOENT
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ if fs.opts.interop == InteropModeShared {
+ // The existence of a dentry at name would be inconclusive because the
+ // file it represents may have been deleted from the remote filesystem,
+ // so we would need to make an RPC to revalidate the dentry. Just
+ // attempt the file creation RPC instead. If a file does exist, the RPC
+ // will fail with EEXIST like we would have. If the RPC succeeds, and a
+ // stale dentry exists, the dentry will fail revalidation next time
+ // it's used.
+ return create(parent, name)
+ }
+ if parent.vfsd.Child(name) != nil {
+ return syserror.EEXIST
+ }
+ // No cached dentry exists; however, there might still be an existing file
+ // at name. As above, we attempt the file creation RPC anyway.
+ if err := create(parent, name); err != nil {
+ return err
+ }
+ parent.touchCMtime(ctx)
+ delete(parent.negativeChildren, name)
+ parent.dirents = nil
+ return nil
+}
+
+// Preconditions: !rp.Done().
+func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ name := rp.Component()
+ if dir {
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ childVFSD := parent.vfsd.Child(name)
+ var child *dentry
+ // We only need a dentry representing the file at name if it can be a mount
+ // point. If childVFSD is nil, then it can't be a mount point. If childVFSD
+ // is non-nil but stale, the actual file can't be a mount point either; we
+ // detect this case by just speculatively calling PrepareDeleteDentry and
+ // only revalidating the dentry if that fails (indicating that the existing
+ // dentry is a mount point).
+ if childVFSD != nil {
+ child = childVFSD.Impl().(*dentry)
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, childVFSD, &ds)
+ if err != nil {
+ return err
+ }
+ if child != nil {
+ childVFSD = &child.vfsd
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
+ return err
+ }
+ } else {
+ childVFSD = nil
+ }
+ }
+ } else if _, ok := parent.negativeChildren[name]; ok {
+ return syserror.ENOENT
+ }
+ flags := uint32(0)
+ if dir {
+ if child != nil && !child.isDir() {
+ return syserror.ENOTDIR
+ }
+ flags = linux.AT_REMOVEDIR
+ } else {
+ if child != nil && child.isDir() {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ err = parent.file.unlinkAt(ctx, name, flags)
+ if err != nil {
+ if childVFSD != nil {
+ vfsObj.AbortDeleteDentry(childVFSD)
+ }
+ return err
+ }
+ if fs.opts.interop != InteropModeShared {
+ parent.touchCMtime(ctx)
+ parent.cacheNegativeChildLocked(name)
+ parent.dirents = nil
+ }
+ if child != nil {
+ child.setDeleted()
+ vfsObj.CommitDeleteDentry(childVFSD)
+ ds = appendDentry(ds, child)
+ }
+ return nil
+}
+
+// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls
+// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkCachingLocked()
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckCaching(ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkCachingLocked()
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ start := rp.Start().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ // 9P2000.L supports hard links, but we don't.
+ return syserror.EPERM
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ return err
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ return err
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Reject O_TMPFILE, which is not supported; supporting it correctly in the
+ // presence of other remote filesystem users requires remote filesystem
+ // support, and it isn't clear that there's any way to implement this in
+ // 9P.
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ return nil, syserror.EOPNOTSUPP
+ }
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ // Get updated metadata for start as required by fs.stepLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ if rp.Done() {
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
+ return nil, err
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ if err != nil {
+ parent.dirMu.Unlock()
+ return nil, err
+ }
+ // Open existing child or follow symlink.
+ parent.dirMu.Unlock()
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats, d.isDir()); err != nil {
+ return nil, err
+ }
+ mnt := rp.Mount()
+ filetype := d.fileType()
+ switch {
+ case filetype == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD:
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil {
+ return nil, err
+ }
+ fd := &regularFileFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case filetype == linux.S_IFDIR:
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
+ return nil, err
+ }
+ fd := &directoryFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case filetype == linux.S_IFLNK:
+ // Can't open symlinks without O_PATH (which is unimplemented).
+ return nil, syserror.ELOOP
+ default:
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ h, err := openHandle(ctx, d.file, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0)
+ if err != nil {
+ return nil, err
+ }
+ fd := &specialFileFD{
+ handle: h,
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ }
+}
+
+// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
+func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ return nil, err
+ }
+ if d.isDeleted() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ creds := rp.Credentials()
+ name := rp.Component()
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, (p9.OpenFlags)(opts.Flags), (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+ // Then we need to walk to the file we just created to get a non-open fid
+ // representing it, and to get its metadata. This must use d.file since, as
+ // explained above, dirfile was invalidated by dirfile.Create().
+ walkQID, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+ // Sanity-check that we walked to the file we created.
+ if createQID.Path != walkQID.Path {
+ // Probably due to concurrent remote filesystem mutation?
+ ctx.Warningf("gofer.dentry.createAndOpenChildLocked: created file has QID %v before walk, QID %v after (interop=%v)", createQID, walkQID, d.fs.opts.interop)
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, syserror.EAGAIN
+ }
+
+ // Construct the new dentry.
+ child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+ // Incorporate the fid that was opened by lcreate.
+ useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
+ if useRegularFileFD {
+ child.handleMu.Lock()
+ child.handle.file = openFile
+ if fdobj != nil {
+ child.handle.fd = int32(fdobj.Release())
+ }
+ child.handleReadable = vfs.MayReadFileWithOpenFlags(opts.Flags)
+ child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags)
+ child.handleMu.Unlock()
+ }
+ // Take a reference on the new dentry to be held by the new file
+ // description. (This reference also means that the new dentry is not
+ // eligible for caching yet, so we don't need to append to a dentry slice.)
+ child.refs = 1
+ // Insert the dentry into the tree.
+ d.IncRef() // reference held by child on its parent d
+ d.vfsd.InsertChild(&child.vfsd, name)
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchCMtime(ctx)
+ delete(d.negativeChildren, name)
+ d.dirents = nil
+ }
+
+ // Finally, construct a file description representing the created file.
+ var childVFSFD *vfs.FileDescription
+ mnt.IncRef()
+ if useRegularFileFD {
+ fd := &regularFileFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ } else {
+ fd := &specialFileFD{
+ handle: handle{
+ file: openFile,
+ fd: -1,
+ },
+ }
+ if fdobj != nil {
+ fd.handle.fd = int32(fdobj.Release())
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ fd.handle.close(ctx)
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ }
+ return childVFSFD, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ if !d.isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return d.readlink(ctx, rp.Mount())
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // Requires 9P support.
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckCaching(&ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParent := oldParentVD.Dentry().Impl().(*dentry)
+ if fs.opts.interop == InteropModeShared {
+ if err := oldParent.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ return err
+ }
+ vfsObj := rp.VirtualFilesystem()
+ // We need a dentry representing the renamed file since, if it's a
+ // directory, we need to check for write permission on it.
+ oldParent.dirMu.Lock()
+ defer oldParent.dirMu.Unlock()
+ renamed, err := fs.revalidateChildLocked(ctx, vfsObj, oldParent, oldName, oldParent.vfsd.Child(oldName), &ds)
+ if err != nil {
+ return err
+ }
+ if renamed == nil {
+ return syserror.ENOENT
+ }
+ if renamed.isDir() {
+ if renamed == newParent || renamed.vfsd.IsAncestorOf(&newParent.vfsd) {
+ return syserror.EINVAL
+ }
+ if oldParent != newParent {
+ if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if oldParent != newParent {
+ if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
+ return err
+ }
+ newParent.dirMu.Lock()
+ defer newParent.dirMu.Unlock()
+ }
+ if newParent.isDeleted() {
+ return syserror.ENOENT
+ }
+ replacedVFSD := newParent.vfsd.Child(newName)
+ var replaced *dentry
+ // This is similar to unlinkAt, except:
+ //
+ // - We revalidate the replaced dentry unconditionally for simplicity.
+ //
+ // - If rp.MustBeDir(), then we need a dentry representing the replaced
+ // file regardless to confirm that it's a directory.
+ if replacedVFSD != nil || rp.MustBeDir() {
+ replaced, err = fs.revalidateChildLocked(ctx, vfsObj, newParent, newName, replacedVFSD, &ds)
+ if err != nil {
+ return err
+ }
+ if replaced != nil {
+ if replaced.isDir() {
+ if !renamed.isDir() {
+ return syserror.EISDIR
+ }
+ } else {
+ if rp.MustBeDir() || renamed.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ replacedVFSD = &replaced.vfsd
+ } else {
+ replacedVFSD = nil
+ }
+ }
+
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+ if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ if fs.opts.interop != InteropModeShared {
+ oldParent.cacheNegativeChildLocked(oldName)
+ oldParent.dirents = nil
+ delete(newParent.negativeChildren, newName)
+ newParent.dirents = nil
+ }
+ vfsObj.CommitRenameReplaceDentry(&renamed.vfsd, &newParent.vfsd, newName, replacedVFSD)
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, true /* dir */)
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount())
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ // Since walking updates metadata for all traversed dentries under
+ // InteropModeShared, including the returned one, we can return cached
+ // metadata here regardless of fs.opts.interop.
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ fsstat, err := d.file.statFS(ctx)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ nameLen := uint64(fsstat.NameLength)
+ if nameLen > maxFilenameLen {
+ nameLen = maxFilenameLen
+ }
+ return linux.Statfs{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: int64(fsstat.BlockSize),
+ Blocks: fsstat.Blocks,
+ BlocksFree: fsstat.BlocksFree,
+ BlocksAvailable: fsstat.BlocksAvailable,
+ Files: fsstat.Files,
+ FilesFree: fsstat.FilesFree,
+ NameLength: nameLen,
+ }, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ return err
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, false /* dir */)
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ return d.listxattr(ctx)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return d.getxattr(ctx, name)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.setxattr(ctx, &opts)
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(&ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.removexattr(ctx, name)
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return vfs.GenericPrependPath(vfsroot, vd, b)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
new file mode 100644
index 000000000..d00850e25
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -0,0 +1,1150 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gofer provides a filesystem implementation that is backed by a 9p
+// server, interchangably referred to as "gofers" throughout this package.
+//
+// Lock order:
+// regularFileFD/directoryFD.mu
+// filesystem.renameMu
+// dentry.dirMu
+// filesystem.syncMu
+// dentry.metadataMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.handleMu
+// dentry.dataMu
+//
+// Locking dentry.dirMu in multiple dentries requires holding
+// filesystem.renameMu for writing.
+package gofer
+
+import (
+ "fmt"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Name is the default filesystem name.
+const Name = "9p"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // mfp is used to allocate memory that caches regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
+
+ // Immutable options.
+ opts filesystemOptions
+
+ // client is the client used by this filesystem. client is immutable.
+ client *p9.Client
+
+ // uid and gid are the effective KUID and KGID of the filesystem's creator,
+ // and are used as the owner and group for files that don't specify one.
+ // uid and gid are immutable.
+ uid auth.KUID
+ gid auth.KGID
+
+ // renameMu serves two purposes:
+ //
+ // - It synchronizes path resolution with renaming initiated by this
+ // client.
+ //
+ // - It is held by path resolution to ensure that reachable dentries remain
+ // valid. A dentry is reachable by path resolution if it has a non-zero
+ // reference count (such that it is usable as vfs.ResolvingPath.Start() or
+ // is reachable from its children), or if it is a child dentry (such that
+ // it is reachable from its parent).
+ renameMu sync.RWMutex
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by renameMu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // dentries contains all dentries in this filesystem. specialFileFDs
+ // contains all open specialFileFDs. These fields are protected by syncMu.
+ syncMu sync.Mutex
+ dentries map[*dentry]struct{}
+ specialFileFDs map[*specialFileFD]struct{}
+}
+
+type filesystemOptions struct {
+ // "Standard" 9P options.
+ fd int
+ aname string
+ interop InteropMode // derived from the "cache" mount option
+ msize uint32
+ version string
+
+ // maxCachedDentries is the maximum number of dentries with 0 references
+ // retained by the client.
+ maxCachedDentries uint64
+
+ // If forcePageCache is true, host FDs may not be used for application
+ // memory mappings even if available; instead, the client must perform its
+ // own caching of regular file pages. This is primarily useful for testing.
+ forcePageCache bool
+
+ // If limitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host FD mappings returned by dentry.(memmap.Mappable).Translate(). This
+ // makes memory accounting behavior more consistent between cases where
+ // host FDs are / are not available, but may increase the frequency of
+ // sentry-handled page faults on files for which a host FD is available.
+ limitHostFDTranslation bool
+
+ // If overlayfsStaleRead is true, O_RDONLY host FDs provided by the remote
+ // filesystem may not be coherent with writable host FDs opened later, so
+ // mappings of the former must be replaced by mappings of the latter. This
+ // is usually only the case when the remote filesystem is an overlayfs
+ // mount on Linux < 4.19.
+ overlayfsStaleRead bool
+
+ // If regularFilesUseSpecialFileFD is true, application FDs representing
+ // regular files will use distinct file handles for each FD, in the same
+ // way that application FDs representing "special files" such as sockets
+ // do. Note that this disables client caching and mmap for regular files.
+ regularFilesUseSpecialFileFD bool
+}
+
+// InteropMode controls the client's interaction with other remote filesystem
+// users.
+type InteropMode uint32
+
+const (
+ // InteropModeExclusive is appropriate when the filesystem client is the
+ // only user of the remote filesystem.
+ //
+ // - The client may cache arbitrary filesystem state (file data, metadata,
+ // filesystem structure, etc.).
+ //
+ // - Client changes to filesystem state may be sent to the remote
+ // filesystem asynchronously, except when server permission checks are
+ // necessary.
+ //
+ // - File timestamps are based on client clocks. This ensures that users of
+ // the client observe timestamps that are coherent with their own clocks
+ // and consistent with Linux's semantics. However, since it is not always
+ // possible for clients to set arbitrary atimes and mtimes, and never
+ // possible for clients to set arbitrary ctimes, file timestamp changes are
+ // stored in the client only and never sent to the remote filesystem.
+ InteropModeExclusive InteropMode = iota
+
+ // InteropModeWritethrough is appropriate when there are read-only users of
+ // the remote filesystem that expect to observe changes made by the
+ // filesystem client.
+ //
+ // - The client may cache arbitrary filesystem state.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on client clocks. As a corollary, access
+ // timestamp changes from other remote filesystem users will not be visible
+ // to the client.
+ InteropModeWritethrough
+
+ // InteropModeShared is appropriate when there are users of the remote
+ // filesystem that may mutate its state other than the client.
+ //
+ // - The client must verify cached filesystem state before using it.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on server clocks. This is necessary to
+ // ensure that timestamp changes are synchronized between remote filesystem
+ // users.
+ //
+ // Note that the correctness of InteropModeShared depends on the server
+ // correctly implementing 9P fids (i.e. each fid immutably represents a
+ // single filesystem object), even in the presence of remote filesystem
+ // mutations from other users. If this is violated, the behavior of the
+ // client is undefined.
+ InteropModeShared
+)
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider")
+ return nil, nil, syserror.EINVAL
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ var fsopts filesystemOptions
+
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+ if trans != "fd" {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.fd = rfd
+
+ // Get the attach name.
+ fsopts.aname = "/"
+ if aname, ok := mopts["aname"]; ok {
+ delete(mopts, "aname")
+ fsopts.aname = aname
+ }
+
+ // Parse the cache policy. For historical reasons, this defaults to the
+ // least generally-applicable option, InteropModeExclusive.
+ fsopts.interop = InteropModeExclusive
+ if cache, ok := mopts["cache"]; ok {
+ delete(mopts, "cache")
+ switch cache {
+ case "fscache":
+ fsopts.interop = InteropModeExclusive
+ case "fscache_writethrough":
+ fsopts.interop = InteropModeWritethrough
+ case "none":
+ fsopts.regularFilesUseSpecialFileFD = true
+ fallthrough
+ case "remote_revalidating":
+ fsopts.interop = InteropModeShared
+ default:
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ // Parse the 9P message size.
+ fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
+ if msizestr, ok := mopts["msize"]; ok {
+ delete(mopts, "msize")
+ msize, err := strconv.ParseUint(msizestr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.msize = uint32(msize)
+ }
+
+ // Parse the 9P protocol version.
+ fsopts.version = p9.HighestVersionString()
+ if version, ok := mopts["version"]; ok {
+ delete(mopts, "version")
+ fsopts.version = version
+ }
+
+ // Parse the dentry cache limit.
+ fsopts.maxCachedDentries = 1000
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.maxCachedDentries = maxCachedDentries
+ }
+
+ // Handle simple flags.
+ if _, ok := mopts["force_page_cache"]; ok {
+ delete(mopts, "force_page_cache")
+ fsopts.forcePageCache = true
+ }
+ if _, ok := mopts["limit_host_fd_translation"]; ok {
+ delete(mopts, "limit_host_fd_translation")
+ fsopts.limitHostFDTranslation = true
+ }
+ if _, ok := mopts["overlayfs_stale_read"]; ok {
+ delete(mopts, "overlayfs_stale_read")
+ fsopts.overlayfsStaleRead = true
+ }
+ // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
+ // "cache=none".
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fsopts.fd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return nil, nil, err
+ }
+ // Ownership of conn has been transferred to client.
+
+ // Perform attach to obtain the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := client.Attach(fsopts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ client.Close()
+ return nil, nil, err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ attachFile.close(ctx)
+ client.Close()
+ return nil, nil, err
+ }
+
+ // Construct the filesystem object.
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ client: client,
+ dentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ }
+ fs.vfsfs.Init(vfsObj, fs)
+
+ // Construct the root dentry.
+ root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
+ if err != nil {
+ attachFile.close(ctx)
+ fs.vfsfs.DecRef()
+ return nil, nil, err
+ }
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is deliberately leaked to prevent the root from
+ // being "cached" and subsequently evicted. Its resources will still be
+ // cleaned up by fs.Release().
+ root.refs = 2
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release() {
+ ctx := context.Background()
+ mf := fs.mfp.MemoryFile()
+
+ fs.syncMu.Lock()
+ for d := range fs.dentries {
+ d.handleMu.Lock()
+ d.dataMu.Lock()
+ if d.handleWritable {
+ // Write dirty cached data to the remote file.
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt); err != nil {
+ log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
+ }
+ // TODO(jamieliu): Do we need to flushf/fsync d?
+ }
+ // Discard cached pages.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ d.dataMu.Unlock()
+ // Close the host fd if one exists.
+ if d.handle.fd >= 0 {
+ syscall.Close(int(d.handle.fd))
+ d.handle.fd = -1
+ }
+ d.handleMu.Unlock()
+ }
+ // There can't be any specialFileFDs still using fs, since each such
+ // FileDescription would hold a reference on a Mount holding a reference on
+ // fs.
+ fs.syncMu.Unlock()
+
+ // Close the connection to the server. This implicitly clunks all fids.
+ fs.client.Close()
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // refs is the reference count. Each dentry holds a reference on its
+ // parent, even if disowned. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // We don't support hard links, so each dentry maps 1:1 to an inode.
+
+ // file is the unopened p9.File that backs this dentry. file is immutable.
+ file p9file
+
+ // If deleted is non-zero, the file represented by this dentry has been
+ // deleted. deleted is accessed using atomic memory operations.
+ deleted uint32
+
+ // If cached is true, dentryEntry links dentry into
+ // filesystem.cachedDentries. cached and dentryEntry are protected by
+ // filesystem.renameMu.
+ cached bool
+ dentryEntry
+
+ dirMu sync.Mutex
+
+ // If this dentry represents a directory, and InteropModeShared is not in
+ // effect, negativeChildren is a set of child names in this directory that
+ // are known not to exist. negativeChildren is protected by dirMu.
+ negativeChildren map[string]struct{}
+
+ // If this dentry represents a directory, InteropModeShared is not in
+ // effect, and dirents is not nil, it is a cache of all entries in the
+ // directory, in the order they were returned by the server. dirents is
+ // protected by dirMu.
+ dirents []vfs.Dirent
+
+ // Cached metadata; protected by metadataMu and accessed using atomic
+ // memory operations unless otherwise specified.
+ metadataMu sync.Mutex
+ ino uint64 // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
+ btime int64
+ // File size, protected by both metadataMu and dataMu (i.e. both must be
+ // locked to mutate it).
+ size uint64
+
+ mapsMu sync.Mutex
+
+ // If this dentry represents a regular file, mappings tracks mappings of
+ // the file into memmap.MappingSpaces. mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // If this dentry represents a regular file or directory:
+ //
+ // - handle is the I/O handle used by all regularFileFDs/directoryFDs
+ // representing this dentry.
+ //
+ // - handleReadable is true if handle is readable.
+ //
+ // - handleWritable is true if handle is writable.
+ //
+ // Invariants:
+ //
+ // - If handleReadable == handleWritable == false, then handle.file == nil
+ // (i.e. there is no open handle). Conversely, if handleReadable ||
+ // handleWritable == true, then handle.file != nil (i.e. there is an open
+ // handle).
+ //
+ // - handleReadable and handleWritable cannot transition from true to false
+ // (i.e. handles may not be downgraded).
+ //
+ // These fields are protected by handleMu.
+ handleMu sync.RWMutex
+ handle handle
+ handleReadable bool
+ handleWritable bool
+
+ dataMu sync.RWMutex
+
+ // If this dentry represents a regular file that is client-cached, cache
+ // maps offsets into the cached file to offsets into
+ // filesystem.mfp.MemoryFile() that store the file's data. cache is
+ // protected by dataMu.
+ cache fsutil.FileRangeSet
+
+ // If this dentry represents a regular file that is client-cached, dirty
+ // tracks dirty segments in cache. dirty is protected by dataMu.
+ dirty fsutil.DirtySet
+
+ // pf implements platform.File for mappings of handle.fd.
+ pf dentryPlatformFile
+
+ // If this dentry represents a symbolic link, InteropModeShared is not in
+ // effect, and haveTarget is true, target is the symlink target. haveTarget
+ // and target are protected by dataMu.
+ haveTarget bool
+ target string
+}
+
+// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
+// gofer client.
+func dentryAttrMask() p9.AttrMask {
+ return p9.AttrMask{
+ Mode: true,
+ UID: true,
+ GID: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ Size: true,
+ BTime: true,
+ }
+}
+
+// newDentry creates a new dentry representing the given file. The dentry
+// initially has no references, but is not cached; it is the caller's
+// responsibility to set the dentry's reference count and/or call
+// dentry.checkCachingLocked() as appropriate.
+func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) {
+ if !mask.Mode {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, syserror.EIO
+ }
+ if attr.Mode.FileType() == p9.ModeRegular && !mask.Size {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, syserror.EIO
+ }
+
+ d := &dentry{
+ fs: fs,
+ file: file,
+ ino: qid.Path,
+ mode: uint32(attr.Mode),
+ uid: uint32(fs.uid),
+ gid: uint32(fs.gid),
+ blockSize: usermem.PageSize,
+ handle: handle{
+ fd: -1,
+ },
+ }
+ d.pf.dentry = d
+ if mask.UID {
+ d.uid = uint32(attr.UID)
+ }
+ if mask.GID {
+ d.gid = uint32(attr.GID)
+ }
+ if mask.Size {
+ d.size = attr.Size
+ }
+ if attr.BlockSize != 0 {
+ d.blockSize = uint32(attr.BlockSize)
+ }
+ if mask.ATime {
+ d.atime = dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)
+ }
+ if mask.MTime {
+ d.mtime = dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)
+ }
+ if mask.CTime {
+ d.ctime = dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds)
+ }
+ if mask.BTime {
+ d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)
+ }
+ d.vfsd.Init(d)
+
+ fs.syncMu.Lock()
+ fs.dentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+// updateFromP9Attrs is called to update d's metadata after an update from the
+// remote filesystem.
+func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) {
+ d.metadataMu.Lock()
+ if mask.Mode {
+ if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
+ d.metadataMu.Unlock()
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ atomic.StoreUint32(&d.mode, uint32(attr.Mode))
+ }
+ if mask.UID {
+ atomic.StoreUint32(&d.uid, uint32(attr.UID))
+ }
+ if mask.GID {
+ atomic.StoreUint32(&d.gid, uint32(attr.GID))
+ }
+ // There is no P9_GETATTR_* bit for I/O block size.
+ if attr.BlockSize != 0 {
+ atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize))
+ }
+ if mask.ATime {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds))
+ }
+ if mask.MTime {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds))
+ }
+ if mask.CTime {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds))
+ }
+ if mask.BTime {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds))
+ }
+ if mask.Size {
+ d.dataMu.Lock()
+ atomic.StoreUint64(&d.size, attr.Size)
+ d.dataMu.Unlock()
+ }
+ d.metadataMu.Unlock()
+}
+
+func (d *dentry) updateFromGetattr(ctx context.Context) error {
+ // Use d.handle.file, which represents a 9P fid that has been opened, in
+ // preference to d.file, which represents a 9P fid that has not. This may
+ // be significantly more efficient in some implementations.
+ var (
+ file p9file
+ handleMuRLocked bool
+ )
+ d.handleMu.RLock()
+ if !d.handle.file.isNil() {
+ file = d.handle.file
+ handleMuRLocked = true
+ } else {
+ file = d.file
+ d.handleMu.RUnlock()
+ }
+ _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
+ if handleMuRLocked {
+ d.handleMu.RUnlock()
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromP9Attrs(attrMask, &attr)
+ return nil
+}
+
+func (d *dentry) fileType() uint32 {
+ return atomic.LoadUint32(&d.mode) & linux.S_IFMT
+}
+
+func (d *dentry) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
+ stat.Blksize = atomic.LoadUint32(&d.blockSize)
+ stat.Nlink = 1
+ if d.isDir() {
+ stat.Nlink = 2
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = d.ino
+ stat.Size = atomic.LoadUint64(&d.size)
+ // This is consistent with regularFileFD.Seek(), which treats regular files
+ // as having no holes.
+ stat.Blocks = (stat.Size + 511) / 512
+ stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
+ stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
+ stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ // TODO(jamieliu): device number
+}
+
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error {
+ if stat.Mask == 0 {
+ return nil
+ }
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(creds, stat, uint16(atomic.LoadUint32(&d.mode))&^linux.S_IFMT, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ setLocalAtime := false
+ setLocalMtime := false
+ if d.fs.opts.interop != InteropModeShared {
+ // Timestamp updates will be handled locally.
+ setLocalAtime = stat.Mask&linux.STATX_ATIME != 0
+ setLocalMtime = stat.Mask&linux.STATX_MTIME != 0
+ stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME
+ if !setLocalMtime && (stat.Mask&linux.STATX_SIZE != 0) {
+ // Truncate updates mtime.
+ setLocalMtime = true
+ stat.Mtime.Nsec = linux.UTIME_NOW
+ }
+ }
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if stat.Mask != 0 {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ return err
+ }
+ }
+ if d.fs.opts.interop == InteropModeShared {
+ // There's no point to updating d's metadata in this case since it'll
+ // be overwritten by revalidation before the next time it's used
+ // anyway. (InteropModeShared inhibits client caching of regular file
+ // data, so there's no cache to truncate either.)
+ return nil
+ }
+ now, haveNow := nowFromContext(ctx)
+ if !haveNow {
+ ctx.Warningf("gofer.dentry.setStat: current time not available")
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, stat.GID)
+ }
+ if setLocalAtime {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ if haveNow {
+ atomic.StoreInt64(&d.atime, now)
+ }
+ } else {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ }
+ }
+ if setLocalMtime {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ if haveNow {
+ atomic.StoreInt64(&d.mtime, now)
+ }
+ } else {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ }
+ }
+ if haveNow {
+ atomic.StoreInt64(&d.ctime, now)
+ }
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Lock()
+ oldSize := d.size
+ d.size = stat.Size
+ // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
+ // below. This allows concurrent calls to Read/Translate/etc. These
+ // functions synchronize with truncation by refusing to use cache
+ // contents beyond the new d.size. (We are still holding d.metadataMu,
+ // so we can't race with Write or another truncate.)
+ d.dataMu.Unlock()
+ if d.size < oldSize {
+ oldpgend := pageRoundUp(oldSize)
+ newpgend := pageRoundUp(d.size)
+ if oldpgend != newpgend {
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ d.mapsMu.Unlock()
+ }
+ // We are now guaranteed that there are no translations of
+ // truncated pages, and can remove them from the cache. Since
+ // truncated pages have been removed from the remote file, they
+ // should be dropped without being written back.
+ d.dataMu.Lock()
+ d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
+ d.dataMu.Unlock()
+ }
+ }
+ return nil
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
+ return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&d.mode))&0777, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkCachingLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked()
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// checkCachingLocked should be called after d's reference count becomes 0 or it
+// becomes disowned.
+//
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) checkCachingLocked() {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.)
+ if atomic.LoadInt64(&d.refs) != 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ // Non-child dentries with zero references are no longer reachable by path
+ // resolution and should be dropped immediately.
+ if d.vfsd.Parent() == nil || d.vfsd.IsDisowned() {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked()
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
+ victim := d.fs.cachedDentries.Back()
+ d.fs.cachedDentries.Remove(victim)
+ d.fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path
+ // resolution since it was inserted into fs.cachedDentries; see
+ // dentry.incRefLocked(). Either way, we brought
+ // fs.cachedDentriesLen back down to fs.opts.maxCachedDentries, so
+ // we don't loop.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victimParentVFSD := victim.vfsd.Parent(); victimParentVFSD != nil {
+ victimParent := victimParentVFSD.Impl().(*dentry)
+ victimParent.dirMu.Lock()
+ if !victim.vfsd.IsDisowned() {
+ // victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount
+ // points.
+ d.fs.vfsfs.VirtualFilesystem().ForceDeleteDentry(&victim.vfsd)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victimParent.dirMu.Unlock()
+ }
+ victim.destroyLocked()
+ }
+ }
+}
+
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is
+// not a child dentry.
+func (d *dentry) destroyLocked() {
+ ctx := context.Background()
+ d.handleMu.Lock()
+ if !d.handle.file.isNil() {
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ // Write dirty pages back to the remote filesystem.
+ if d.handleWritable {
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ log.Warningf("gofer.dentry.DecRef: failed to write dirty data back: %v", err)
+ }
+ }
+ // Discard cached data.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ d.dataMu.Unlock()
+ // Clunk open fids and close open host FDs.
+ d.handle.close(ctx)
+ }
+ d.handleMu.Unlock()
+ d.file.close(ctx)
+ // Remove d from the set of all dentries.
+ d.fs.syncMu.Lock()
+ delete(d.fs.dentries, d)
+ d.fs.syncMu.Unlock()
+ // Drop the reference held by d on its parent.
+ if parentVFSD := d.vfsd.Parent(); parentVFSD != nil {
+ parent := parentVFSD.Impl().(*dentry)
+ // This is parent.DecRef() without recursive locking of d.fs.renameMu.
+ if refs := atomic.AddInt64(&parent.refs, -1); refs == 0 {
+ parent.checkCachingLocked()
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+func (d *dentry) isDeleted() bool {
+ return atomic.LoadUint32(&d.deleted) != 0
+}
+
+func (d *dentry) setDeleted() {
+ atomic.StoreUint32(&d.deleted, 1)
+}
+
+func (d *dentry) listxattr(ctx context.Context) ([]string, error) {
+ return nil, syserror.ENOTSUP
+}
+
+func (d *dentry) getxattr(ctx context.Context, name string) (string, error) {
+ // TODO(jamieliu): add vfs.GetxattrOptions.Size
+ return d.file.getXattr(ctx, name, linux.XATTR_SIZE_MAX)
+}
+
+func (d *dentry) setxattr(ctx context.Context, opts *vfs.SetxattrOptions) error {
+ return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
+}
+
+func (d *dentry) removexattr(ctx context.Context, name string) error {
+ return syserror.ENOTSUP
+}
+
+// Preconditions: d.isRegularFile() || d.isDirectory().
+func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
+ // O_TRUNC unconditionally requires us to obtain a new handle (opened with
+ // O_TRUNC).
+ if !trunc {
+ d.handleMu.RLock()
+ if (!read || d.handleReadable) && (!write || d.handleWritable) {
+ // The current handle is sufficient.
+ d.handleMu.RUnlock()
+ return nil
+ }
+ d.handleMu.RUnlock()
+ }
+
+ haveOldFD := false
+ d.handleMu.Lock()
+ if (read && !d.handleReadable) || (write && !d.handleWritable) || trunc {
+ // Get a new handle.
+ wantReadable := d.handleReadable || read
+ wantWritable := d.handleWritable || write
+ h, err := openHandle(ctx, d.file, wantReadable, wantWritable, trunc)
+ if err != nil {
+ d.handleMu.Unlock()
+ return err
+ }
+ if !d.handle.file.isNil() {
+ // Check that old and new handles are compatible: If the old handle
+ // includes a host file descriptor but the new one does not, or
+ // vice versa, old and new memory mappings may be incoherent.
+ haveOldFD = d.handle.fd >= 0
+ haveNewFD := h.fd >= 0
+ if haveOldFD != haveNewFD {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: can't change host FD availability from %v to %v across dentry handle upgrade", haveOldFD, haveNewFD)
+ h.close(ctx)
+ return syserror.EIO
+ }
+ if haveOldFD {
+ // We may have raced with callers of d.pf.FD() that are now
+ // using the old file descriptor, preventing us from safely
+ // closing it. We could handle this by invalidating existing
+ // memmap.Translations, but this is expensive. Instead, use
+ // dup2() to make the old file descriptor refer to the new file
+ // description, then close the new file descriptor (which is no
+ // longer needed). Racing callers may use the old or new file
+ // description, but this doesn't matter since they refer to the
+ // same file (unless d.fs.opts.overlayfsStaleRead is true,
+ // which we handle separately).
+ if err := syscall.Dup2(int(h.fd), int(d.handle.fd)); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err)
+ h.close(ctx)
+ return err
+ }
+ syscall.Close(int(h.fd))
+ h.fd = d.handle.fd
+ if d.fs.opts.overlayfsStaleRead {
+ // Replace sentry mappings of the old FD with mappings of
+ // the new FD, since the two are not necessarily coherent.
+ if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err)
+ h.close(ctx)
+ return err
+ }
+ }
+ // Clunk the old fid before making the new handle visible (by
+ // unlocking d.handleMu).
+ d.handle.file.close(ctx)
+ }
+ }
+ // Switch to the new handle.
+ d.handle = h
+ d.handleReadable = wantReadable
+ d.handleWritable = wantWritable
+ }
+ d.handleMu.Unlock()
+
+ if d.fs.opts.overlayfsStaleRead && haveOldFD {
+ // Invalidate application mappings that may be using the old FD; they
+ // will be replaced with mappings using the new FD after future calls
+ // to d.Translate(). This requires holding d.mapsMu, which precedes
+ // d.handleMu in the lock order.
+ d.mapsMu.Lock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+ }
+
+ return nil
+}
+
+// fileDescription is embedded by gofer implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ d := fd.dentry()
+ if d.fs.opts.interop == InteropModeShared && opts.Mask&(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE|linux.STATX_BLOCKS|linux.STATX_BTIME) != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
+ // available?
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ return fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount())
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context) ([]string, error) {
+ return fd.dentry().listxattr(ctx)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, name string) (string, error) {
+ return fd.dentry().getxattr(ctx, name)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ return fd.dentry().setxattr(ctx, &opts)
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ return fd.dentry().removexattr(ctx, name)
+}
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
new file mode 100644
index 000000000..cfe66f797
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -0,0 +1,135 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// handle represents a remote "open file descriptor", consisting of an opened
+// fid (p9.File) and optionally a host file descriptor.
+type handle struct {
+ file p9file
+ fd int32 // -1 if unavailable
+}
+
+// Preconditions: read || write.
+func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (handle, error) {
+ _, newfile, err := file.walk(ctx, nil)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ var flags p9.OpenFlags
+ switch {
+ case read && !write:
+ flags = p9.ReadOnly
+ case !read && write:
+ flags = p9.WriteOnly
+ case read && write:
+ flags = p9.ReadWrite
+ }
+ if trunc {
+ flags |= p9.OpenTruncate
+ }
+ fdobj, _, _, err := newfile.open(ctx, flags)
+ if err != nil {
+ newfile.close(ctx)
+ return handle{fd: -1}, err
+ }
+ fd := int32(-1)
+ if fdobj != nil {
+ fd = int32(fdobj.Release())
+ }
+ return handle{
+ file: newfile,
+ fd: fd,
+ }, nil
+}
+
+func (h *handle) close(ctx context.Context) {
+ h.file.close(ctx)
+ h.file = p9file{}
+ if h.fd >= 0 {
+ syscall.Close(int(h.fd))
+ h.fd = -1
+ }
+}
+
+func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostPreadv(h.fd, dsts, int64(offset))
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
+ n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the read since p9.File.ReadAt() takes []byte.
+ buf := make([]byte, dsts.NumBytes())
+ n, err := h.file.readAt(ctx, buf, offset)
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
+ return cp, cperr
+ }
+ return uint64(n), err
+}
+
+func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostPwritev(h.fd, srcs, int64(offset))
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
+ n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the write since p9.File.WriteAt() takes []byte.
+ buf := make([]byte, srcs.NumBytes())
+ cp, cperr := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), srcs)
+ if cp == 0 {
+ return 0, cperr
+ }
+ n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ if err != nil {
+ return uint64(n), err
+ }
+ return cp, cperr
+}
+
+func (h *handle) sync(ctx context.Context) error {
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(h.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ return h.file.fsync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/handle_unsafe.go b/pkg/sentry/fsimpl/gofer/handle_unsafe.go
new file mode 100644
index 000000000..19560ab26
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/handle_unsafe.go
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// Preconditions: !dsts.IsEmpty().
+func hostPreadv(fd int32, dsts safemem.BlockSeq, off int64) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ if dsts.NumBlocks() == 1 {
+ // Use pread() instead of preadv() to avoid iovec allocation and
+ // copying.
+ dst := dsts.Head()
+ n, _, e := syscall.Syscall6(syscall.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(off), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+ }
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ n, _, e := syscall.Syscall6(syscall.SYS_PREADV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+}
+
+// Preconditions: !srcs.IsEmpty().
+func hostPwritev(fd int32, srcs safemem.BlockSeq, off int64) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ if srcs.NumBlocks() == 1 {
+ // Use pwrite() instead of pwritev() to avoid iovec allocation and
+ // copying.
+ src := srcs.Head()
+ n, _, e := syscall.Syscall6(syscall.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(off), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+ }
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ n, _, e := syscall.Syscall6(syscall.SYS_PWRITEV, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(off), 0, 0)
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
new file mode 100644
index 000000000..755ac2985
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -0,0 +1,219 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// p9file is a wrapper around p9.File that provides methods that are
+// Context-aware.
+type p9file struct {
+ file p9.File
+}
+
+func (f p9file) isNil() bool {
+ return f.file == nil
+}
+
+func (f p9file) walk(ctx context.Context, names []string) ([]p9.QID, p9file, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, err := f.file.Walk(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, err
+}
+
+func (f p9file) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, attrMask, attr, err
+}
+
+// walkGetAttrOne is a wrapper around p9.File.WalkGetAttr that takes a single
+// path component and returns a single qid.
+func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr([]string{name})
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, err
+ }
+ if len(qids) != 1 {
+ ctx.Warningf("p9.File.WalkGetAttr returned %d qids (%v), wanted 1", len(qids), qids)
+ if newfile != nil {
+ p9file{newfile}.close(ctx)
+ }
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO
+ }
+ return qids[0], p9file{newfile}, attrMask, attr, nil
+}
+
+func (f p9file) statFS(ctx context.Context) (p9.FSStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fsstat, err := f.file.StatFS()
+ ctx.UninterruptibleSleepFinish(false)
+ return fsstat, err
+}
+
+func (f p9file) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, attrMask, attr, err := f.file.GetAttr(req)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, attrMask, attr, err
+}
+
+func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetAttr(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := f.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Allocate(mode, offset, length)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) close(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Close()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, qid, iounit, err := f.file.Open(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, qid, iounit, err
+}
+
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.ReadAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.WriteAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) fsync(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.FSync()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9file, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, newfile, qid, iounit, err := f.file.Create(name, flags, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, p9file{newfile}, qid, iounit, err
+}
+
+func (f p9file) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mkdir(name, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Symlink(oldName, newName, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) link(ctx context.Context, target p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Link(target.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) mknod(ctx context.Context, name string, mode p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mknod(name, mode, major, minor, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) rename(ctx context.Context, newDir p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Rename(newDir.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) unlinkAt(ctx context.Context, name string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.UnlinkAt(name, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) {
+ ctx.UninterruptibleSleepStart(false)
+ dirents, err := f.file.Readdir(offset, count)
+ ctx.UninterruptibleSleepFinish(false)
+ return dirents, err
+}
+
+func (f p9file) readlink(ctx context.Context) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ target, err := f.file.Readlink()
+ ctx.UninterruptibleSleepFinish(false)
+ return target, err
+}
+
+func (f p9file) flush(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Flush()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, err := f.file.Connect(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/pagemath.go b/pkg/sentry/fsimpl/gofer/pagemath.go
new file mode 100644
index 000000000..847cb0784
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/pagemath.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// This are equivalent to usermem.Addr.RoundDown/Up, but without the
+// potentially truncating conversion to usermem.Addr. This is necessary because
+// there is no way to define generic "PageRoundDown/Up" functions in Go.
+
+func pageRoundDown(x uint64) uint64 {
+ return x &^ (usermem.PageSize - 1)
+}
+
+func pageRoundUp(x uint64) uint64 {
+ return pageRoundDown(x + usermem.PageSize - 1)
+}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
new file mode 100644
index 000000000..54c1031a7
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -0,0 +1,865 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isRegularFile() bool {
+ return d.fileType() == linux.S_IFREG
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release() {
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *regularFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ // Skip flushing if writes may be buffered by the client, since (as with
+ // the VFS1 client) we don't flush buffered writes on close anyway.
+ d := fd.dentry()
+ if d.fs.opts.interop == InteropModeExclusive {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return d.handle.file.flush(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Check for reading at EOF before calling into MM (but not under
+ // InteropModeShared, which makes d.size unreliable).
+ d := fd.dentry()
+ if d.fs.opts.interop != InteropModeShared && uint64(offset) >= atomic.LoadUint64(&d.size) {
+ return 0, io.EOF
+ }
+
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Lock d.metadataMu for the rest of the read to prevent d.size from
+ // changing.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Write dirty cached pages that will be touched by the read back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil {
+ return 0, err
+ }
+ }
+
+ rw := getDentryReadWriter(ctx, d, offset)
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Require the read to go to the remote file.
+ rw.direct = true
+ }
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putDentryReadWriter(rw)
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ d.touchAtime(ctx, fd.vfsfd.Mount())
+ }
+ return n, err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
+ // file_update_time(). This is d.touchCMtime(), but without locking
+ // d.metadataMu (recursively).
+ if now, ok := nowFromContext(ctx); ok {
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+ }
+ }
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Write dirty cached pages that will be touched by the write back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return 0, err
+ }
+ // Remove touched pages from the cache.
+ pgstart := pageRoundDown(uint64(offset))
+ pgend := pageRoundUp(uint64(offset + src.NumBytes()))
+ if pgend < pgstart {
+ return 0, syserror.EINVAL
+ }
+ mr := memmap.MappableRange{pgstart, pgend}
+ var freed []platform.FileRange
+ d.dataMu.Lock()
+ cseg := d.cache.LowerBoundSegment(mr.Start)
+ for cseg.Ok() && cseg.Start() < mr.End {
+ cseg = d.cache.Isolate(cseg, mr)
+ freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ cseg = d.cache.Remove(cseg).NextSegment()
+ }
+ d.dataMu.Unlock()
+ // Invalidate mappings of removed pages.
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(mr, memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+ // Finally free pages removed from the cache.
+ mf := d.fs.mfp.MemoryFile()
+ for _, freedFR := range freed {
+ mf.DecRef(freedFR)
+ }
+ }
+ rw := getDentryReadWriter(ctx, d, offset)
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Require the write to go to the remote file.
+ rw.direct = true
+ }
+ n, err := src.CopyInTo(ctx, rw)
+ putDentryReadWriter(rw)
+ if n != 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ // Write dirty cached pages touched by the write back to the remote
+ // file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return 0, err
+ }
+ // Request the remote filesystem to sync the remote file.
+ if err := d.handle.file.fsync(ctx); err != nil {
+ return 0, err
+ }
+ }
+ return n, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+type dentryReadWriter struct {
+ ctx context.Context
+ d *dentry
+ off uint64
+ direct bool
+}
+
+var dentryReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &dentryReadWriter{}
+ },
+}
+
+func getDentryReadWriter(ctx context.Context, d *dentry, offset int64) *dentryReadWriter {
+ rw := dentryReadWriterPool.Get().(*dentryReadWriter)
+ rw.ctx = ctx
+ rw.d = d
+ rw.off = uint64(offset)
+ rw.direct = false
+ return rw
+}
+
+func putDentryReadWriter(rw *dentryReadWriter) {
+ rw.ctx = nil
+ rw.d = nil
+ dentryReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents and makes dentry.size
+ // unreliable), or if the file was opened O_DIRECT, read directly from
+ // dentry.handle without locking dentry.dataMu.
+ rw.d.handleMu.RLock()
+ if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := rw.d.handle.readToBlocksAt(rw.ctx, dsts, rw.off)
+ rw.d.handleMu.RUnlock()
+ rw.off += n
+ return n, err
+ }
+
+ // Otherwise read from/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ fillCache := mf.ShouldCacheEvictable()
+ var dataMuUnlock func()
+ if fillCache {
+ rw.d.dataMu.Lock()
+ dataMuUnlock = rw.d.dataMu.Unlock
+ } else {
+ rw.d.dataMu.RLock()
+ dataMuUnlock = rw.d.dataMu.RUnlock
+ }
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.d.size {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.d.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ gapMR := gap.Range().Intersect(mr)
+ if fillCache {
+ // Read into the cache, then re-enter the loop to read from the
+ // cache.
+ reqMR := memmap.MappableRange{
+ Start: pageRoundDown(gapMR.Start),
+ End: pageRoundUp(gapMR.End),
+ }
+ optMR := gap.Range()
+ err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt)
+ mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End})
+ seg, gap = rw.d.cache.Find(rw.off)
+ if !seg.Ok() {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+ // err might have occurred in part of gap.Range() outside
+ // gapMR. Forget about it for now; if the error matters and
+ // persists, we'll run into it again in a later iteration of
+ // this loop.
+ } else {
+ // Read directly from the file.
+ gapDsts := dsts.TakeFirst64(gapMR.Length())
+ n, err := rw.d.handle.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != gapDsts.NumBytes() || err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ }
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: rw.d.metadataMu must be locked.
+func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents), or if the file was
+ // opened with O_DIRECT, write directly to dentry.handle without locking
+ // dentry.dataMu.
+ rw.d.handleMu.RLock()
+ if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off)
+ rw.d.handleMu.RUnlock()
+ rw.off += n
+ return n, err
+ }
+
+ // Otherwise write to/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ rw.d.dataMu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ start := rw.off
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ segMR := seg.Range().Intersect(mr)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ rw.d.dirty.MarkDirty(segMR)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Write directly to the file. At present, we never fill the cache
+ // when writing, since doing so can convert small writes into
+ // inefficient read-modify-write cycles, and we have no mechanism
+ // for detecting or avoiding this.
+ gapMR := gap.Range().Intersect(mr)
+ gapSrcs := srcs.TakeFirst64(gapMR.Length())
+ n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ // Partial writes are fine. But we must stop writing.
+ if n != gapSrcs.NumBytes() || err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ // If InteropModeWritethrough is in effect, flush written data back to the
+ // remote filesystem.
+ if rw.d.fs.opts.interop == InteropModeWritethrough && done != 0 {
+ if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{
+ Start: start,
+ End: rw.off,
+ }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, rw.d.handle.writeFromBlocksAt); err != nil {
+ // We have no idea how many bytes were actually flushed.
+ rw.off = start
+ done = 0
+ retErr = err
+ }
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
+ return done, retErr
+}
+
+func (d *dentry) writeback(ctx context.Context, offset, size int64) error {
+ if size == 0 {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ // Compute the range of valid bytes (overflow-checked).
+ if uint64(offset) >= d.size {
+ return nil
+ }
+ end := int64(d.size)
+ if rend := offset + size; rend > offset && rend < end {
+ end = rend
+ }
+ return fsutil.SyncDirty(ctx, memmap.MappableRange{
+ Start: uint64(offset),
+ End: uint64(end),
+ }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as specified.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Ensure file size is up to date.
+ d := fd.dentry()
+ if fd.filesystem().opts.interop == InteropModeShared {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, err
+ }
+ }
+ size := int64(atomic.LoadUint64(&d.size))
+ // For SEEK_DATA and SEEK_HOLE, treat the file as a single contiguous
+ // block of data.
+ switch whence {
+ case linux.SEEK_END:
+ offset += size
+ case linux.SEEK_DATA:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ // Use offset as specified.
+ case linux.SEEK_HOLE:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ offset = size
+ }
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return fd.dentry().syncSharedHandle(ctx)
+}
+
+func (d *dentry) syncSharedHandle(ctx context.Context) error {
+ d.handleMu.RLock()
+ if !d.handleWritable {
+ d.handleMu.RUnlock()
+ return nil
+ }
+ d.dataMu.Lock()
+ // Write dirty cached data to the remote file.
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err == nil {
+ // Sync the remote file.
+ err = d.handle.sync(ctx)
+ }
+ d.handleMu.RUnlock()
+ return err
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ d := fd.dentry()
+ switch d.fs.opts.interop {
+ case InteropModeExclusive:
+ // Any mapping is fine.
+ case InteropModeWritethrough:
+ // Shared writable mappings require a host FD, since otherwise we can't
+ // synchronously flush memory-mapped writes to the remote file.
+ if opts.Private || !opts.MaxPerms.Write {
+ break
+ }
+ fallthrough
+ case InteropModeShared:
+ // All mappings require a host FD to be coherent with other filesystem
+ // users.
+ if d.fs.opts.forcePageCache {
+ // Whether or not we have a host FD, we're not allowed to use it.
+ return syserror.ENODEV
+ }
+ d.handleMu.RLock()
+ haveFD := d.handle.fd >= 0
+ d.handleMu.RUnlock()
+ if !haveFD {
+ return syserror.ENODEV
+ }
+ default:
+ panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
+ }
+ // After this point, d may be used as a memmap.Mappable.
+ d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
+}
+
+func (d *dentry) mayCachePages() bool {
+ if d.fs.opts.interop == InteropModeShared {
+ return false
+ }
+ if d.fs.opts.forcePageCache {
+ return true
+ }
+ d.handleMu.RLock()
+ haveFD := d.handle.fd >= 0
+ d.handleMu.RUnlock()
+ return haveFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ mapped := d.mappings.AddMapping(ms, ar, offset, writable)
+ // Do this unconditionally since whether we have a host FD can change
+ // across save/restore.
+ for _, r := range mapped {
+ d.pf.hostFileMapper.IncRefOn(r)
+ }
+ if d.mayCachePages() {
+ // d.Evict() will refuse to evict memory-mapped pages, so tell the
+ // MemoryFile to not bother trying.
+ mf := d.fs.mfp.MemoryFile()
+ for _, r := range mapped {
+ mf.MarkUnevictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ }
+ d.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ d.mapsMu.Lock()
+ unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ d.pf.hostFileMapper.DecRefOn(r)
+ }
+ if d.mayCachePages() {
+ // Pages that are no longer referenced by any application memory
+ // mappings are now considered unused; allow MemoryFile to evict them
+ // when necessary.
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ for _, r := range unmapped {
+ // Since these pages are no longer mapped, they are no longer
+ // concurrently dirtyable by a writable memory mapping.
+ d.dirty.AllowClean(r)
+ mf.MarkEvictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ d.dataMu.Unlock()
+ }
+ d.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return d.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ d.handleMu.RLock()
+ if d.handle.fd >= 0 && !d.fs.opts.forcePageCache {
+ d.handleMu.RUnlock()
+ mr := optional
+ if d.fs.opts.limitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &d.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+ }
+
+ d.dataMu.Lock()
+
+ // Constrain translations to d.size (rounded up) to prevent translation to
+ // pages that may be concurrently truncated.
+ pgend := pageRoundUp(d.size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := d.fs.mfp.MemoryFile()
+ cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, d.handle.readToBlocksAt)
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := d.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ // TODO(jamieliu): Make Translations writable even if writability is
+ // not required if already kept-dirty by another writable translation.
+ perms := usermem.AccessType{
+ Read: true,
+ Execute: true,
+ }
+ if at.Write {
+ // From this point forward, this memory can be dirtied through the
+ // mapping at any time.
+ d.dirty.KeepDirty(segMR)
+ perms.Write = true
+ }
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: perms,
+ })
+ translatedEnd = segMR.End
+ }
+
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+
+ // Don't return the error returned by c.cache.Fill if it occurred outside
+ // of required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange {
+ const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily
+ if required.Length() >= maxReadahead {
+ return required
+ }
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.Start = required.Start
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.End = optional.Start + maxReadahead
+ return optional
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
+ // Whether we have a host fd (and consequently what platform.File is
+ // mapped) can change across save/restore, so invalidate all translations
+ // unconditionally.
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Write the cache's contents back to the remote file so that if we have a
+ // host fd after restore, the remote file's contents are coherent.
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ return err
+ }
+
+ // Discard the cache so that it's not stored in saved state. This is safe
+ // because per InvalidateUnsavable invariants, no new translations can have
+ // been returned after we invalidated all existing translations above.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+
+ return nil
+}
+
+// Evict implements pgalloc.EvictableMemoryUser.Evict.
+func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+
+ mr := memmap.MappableRange{er.Start, er.End}
+ mf := d.fs.mfp.MemoryFile()
+ // Only allow pages that are no longer memory-mapped to be evicted.
+ for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() {
+ mgapMR := mgap.Range().Intersect(mr)
+ if mgapMR.Length() == 0 {
+ continue
+ }
+ if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil {
+ log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err)
+ }
+ d.cache.Drop(mgapMR, mf)
+ d.dirty.KeepClean(mgapMR)
+ }
+}
+
+// dentryPlatformFile implements platform.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef.
+//
+// dentryPlatformFile is only used when a host FD representing the remote file
+// is available (i.e. dentry.handle.fd >= 0), and that FD is used for
+// application memory mappings (i.e. !filesystem.opts.forcePageCache).
+type dentryPlatformFile struct {
+ *dentry
+
+ // fdRefs counts references on platform.File offsets. fdRefs is protected
+ // by dentry.dataMu.
+ fdRefs fsutil.FrameRefSet
+
+ // If this dentry represents a regular file, and handle.fd >= 0,
+ // hostFileMapper caches mappings of handle.fd.
+ hostFileMapper fsutil.HostFileMapper
+
+ // hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
+ hostFileMapperInitOnce sync.Once
+}
+
+// IncRef implements platform.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
+ d.dataMu.Lock()
+ seg, gap := d.fdRefs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = d.fdRefs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ seg, gap = d.fdRefs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ d.fdRefs.MergeAdjacent(fr)
+ d.dataMu.Unlock()
+ return
+ }
+ }
+}
+
+// DecRef implements platform.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
+ d.dataMu.Lock()
+ seg := d.fdRefs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = d.fdRefs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ seg = d.fdRefs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ d.fdRefs.MergeAdjacent(fr)
+ d.dataMu.Unlock()
+
+}
+
+// MapInternal implements platform.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ d.handleMu.RLock()
+ bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write)
+ d.handleMu.RUnlock()
+ return bs, err
+}
+
+// FD implements platform.File.FD.
+func (d *dentryPlatformFile) FD() int {
+ d.handleMu.RLock()
+ fd := d.handle.fd
+ d.handleMu.RUnlock()
+ return int(fd)
+}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
new file mode 100644
index 000000000..08c691c47
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -0,0 +1,159 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// specialFileFD implements vfs.FileDescriptionImpl for files other than
+// regular files, directories, and symlinks: pipes, sockets, etc. It is also
+// used for regular files when filesystemOptions.specialRegularFiles is in
+// effect. specialFileFD differs from regularFileFD by using per-FD handles
+// instead of shared per-dentry handles, and never buffering I/O.
+type specialFileFD struct {
+ fileDescription
+
+ // handle is immutable.
+ handle handle
+
+ // off is the file offset. off is protected by mu. (POSIX 2.9.7 only
+ // requires operations using the file offset to be atomic for regular files
+ // and symlinks; however, since specialFileFD may be used for regular
+ // files, we apply this atomicity unconditionally.)
+ mu sync.Mutex
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *specialFileFD) Release() {
+ fd.handle.close(context.Background())
+ fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+ fs.syncMu.Lock()
+ delete(fs.specialFileFDs, fd)
+ fs.syncMu.Unlock()
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *specialFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ return fd.handle.file.flush(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Going through dst.CopyOutFrom() holds MM locks around file operations of
+ // unknown duration. For regularFileFD, doing so is necessary to support
+ // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
+ // hold here since specialFileFD doesn't client-cache data. Just buffer the
+ // read instead.
+ if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(ctx, fd.vfsfd.Mount())
+ }
+ buf := make([]byte, dst.NumBytes())
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
+ return int64(cp), cperr
+ }
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Do a buffered write. See rationale in PRead.
+ if d := fd.dentry(); d.fs.opts.interop != InteropModeShared {
+ d.touchCMtime(ctx)
+ }
+ buf := make([]byte, src.NumBytes())
+ // Don't do partial writes if we get a partial read from src.
+ if _, err := src.CopyIn(ctx, buf); err != nil {
+ return 0, err
+ }
+ n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ return int64(n), err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PWrite(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ default:
+ // SEEK_END, SEEK_DATA, and SEEK_HOLE aren't supported since it's not
+ // clear that file size is even meaningful for these files.
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *specialFileFD) Sync(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ return fd.handle.sync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
new file mode 100644
index 000000000..adf43be60
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func (d *dentry) isSymlink() bool {
+ return d.fileType() == linux.S_IFLNK
+}
+
+// Precondition: d.isSymlink().
+func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(ctx, mnt)
+ d.dataMu.Lock()
+ if d.haveTarget {
+ target := d.target
+ d.dataMu.Unlock()
+ return target, nil
+ }
+ }
+ target, err := d.file.readlink(ctx)
+ if d.fs.opts.interop != InteropModeShared {
+ if err == nil {
+ d.haveTarget = true
+ d.target = target
+ }
+ d.dataMu.Unlock()
+ }
+ return target, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
new file mode 100644
index 000000000..7598ec6a8
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func dentryTimestampFromP9(s, ns uint64) int64 {
+ return int64(s*1e9 + ns)
+}
+
+func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
+ return ts.Sec*1e9 + int64(ts.Nsec)
+}
+
+func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
+ return linux.StatxTimestamp{
+ Sec: ns / 1e9,
+ Nsec: uint32(ns % 1e9),
+ }
+}
+
+func nowFromContext(ctx context.Context) (int64, bool) {
+ if clock := ktime.RealtimeClockFromContext(ctx); clock != nil {
+ return clock.Now().Nanoseconds(), true
+ }
+ return 0, false
+}
+
+// Preconditions: fs.interop != InteropModeShared.
+func (d *dentry) touchAtime(ctx context.Context, mnt *vfs.Mount) {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now, ok := nowFromContext(ctx)
+ if !ok {
+ mnt.EndWrite()
+ return
+ }
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.atime, now)
+ d.metadataMu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: fs.interop != InteropModeShared. The caller has successfully
+// called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCMtime(ctx context.Context) {
+ now, ok := nowFromContext(ctx)
+ if !ok {
+ return
+ }
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 733792c78..d092ccb2a 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -53,9 +53,9 @@ func (f *DynamicBytesFile) Init(creds *auth.Credentials, ino uint64, data vfs.Dy
}
// Open implements Inode.Open.
-func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (f *DynamicBytesFile) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &DynamicBytesFD{}
- if err := fd.Init(rp.Mount(), vfsd, f.data, flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, f.data, opts.Flags); err != nil {
return nil, err
}
return &fd.vfsfd, nil
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
index 6104751c8..5650512e0 100644
--- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -43,12 +43,12 @@ type GenericDirectoryFD struct {
}
// Init initializes a GenericDirectoryFD.
-func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, flags uint32) error {
- if vfs.AccessTypesForOpenFlags(flags)&vfs.MayWrite != 0 {
+func (fd *GenericDirectoryFD) Init(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, opts *vfs.OpenOptions) error {
+ if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
// Can't open directories for writing.
return syserror.EISDIR
}
- if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
return err
}
fd.children = children
@@ -116,8 +116,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: 1,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
@@ -132,8 +132,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: 2,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
@@ -153,8 +153,8 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent
Ino: stat.Ino,
NextOff: fd.off + 1,
}
- if !cb.Handle(dirent) {
- return nil
+ if err := cb.Handle(dirent); err != nil {
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 9d65d0179..292f58afd 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -111,10 +111,10 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
// Dentry isn't cached; it either doesn't exist or failed
// revalidation. Attempt to resolve it via Lookup.
//
- // FIXME(b/144498111): Inode.Lookup() should return *(kernfs.)Dentry,
- // not *vfs.Dentry, since (kernfs.)Filesystem assumes that all dentries
- // in the filesystem are (kernfs.)Dentry and performs vfs.DentryImpl
- // casts accordingly.
+ // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return
+ // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes
+ // that all dentries in the filesystem are (kernfs.)Dentry and performs
+ // vfs.DentryImpl casts accordingly.
var err error
childVFSD, err = parent.inode.Lookup(ctx, name)
if err != nil {
@@ -365,7 +365,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// appropriate bits in rp), but are returned by
// FileDescriptionImpl.StatusFlags().
opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW
- ats := vfs.AccessTypesForOpenFlags(opts.Flags)
+ ats := vfs.AccessTypesForOpenFlags(&opts)
// Do not create new file.
if opts.Flags&linux.O_CREAT == 0 {
@@ -379,7 +379,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return inode.Open(rp, vfsd, opts.Flags)
+ return inode.Open(rp, vfsd, opts)
}
// May create new file.
@@ -398,7 +398,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return inode.Open(rp, vfsd, opts.Flags)
+ return inode.Open(rp, vfsd, opts)
}
afterTrailingSymlink:
parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
@@ -438,7 +438,7 @@ afterTrailingSymlink:
return nil, err
}
parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
- return child.Impl().(*Dentry).inode.Open(rp, child, opts.Flags)
+ return child.Impl().(*Dentry).inode.Open(rp, child, opts)
}
// Open existing file or follow symlink.
if mustCreate {
@@ -463,7 +463,7 @@ afterTrailingSymlink:
if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
- return childInode.Open(rp, childVFSD, opts.Flags)
+ return childInode.Open(rp, childVFSD, opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
@@ -544,6 +544,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
}
mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
virtfs := rp.VirtualFilesystem()
srcDirDentry := srcDirVFSD.Impl().(*Dentry)
@@ -595,7 +596,10 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
parentDentry := vfsd.Parent().Impl().(*Dentry)
parentDentry.dirMu.Lock()
defer parentDentry.dirMu.Unlock()
- if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
@@ -697,7 +701,9 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
parentDentry := vfsd.Parent().Impl().(*Dentry)
parentDentry.dirMu.Lock()
defer parentDentry.dirMu.Unlock()
- if err := virtfs.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
return err
}
if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index adca2313f..099d70a16 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -507,7 +507,7 @@ type InodeSymlink struct {
}
// Open implements Inode.Open.
-func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
return nil, syserror.ELOOP
}
@@ -549,8 +549,8 @@ func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.F
}
// Open implements kernfs.Inode.
-func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 79ebea8a5..c74fa999b 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -303,7 +303,7 @@ type Inode interface {
// inode for its lifetime.
//
// Precondition: !rp.Done(). vfsd.Impl() must be a kernfs Dentry.
- Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error)
+ Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
}
type inodeRefs interface {
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index ee65cf491..0459fb305 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -45,7 +45,10 @@ type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
ctx := contexttest.Context(t)
creds := auth.CredentialsFromContext(ctx)
- v := vfs.New()
+ v := &vfs.VirtualFilesystem{}
+ if err := v.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -113,9 +116,9 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod
return &dir.dentry
}
-func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *readonlyDir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags); err != nil {
+ if err := fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts); err != nil {
return nil, err
}
return fd.VFSFileDescription(), nil
@@ -143,9 +146,9 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte
return &dir.dentry
}
-func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 12aac2e6a..a83245866 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -14,6 +14,7 @@ go_library(
"tasks_net.go",
"tasks_sys.go",
],
+ visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 11477b6a9..5c19d5522 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -26,15 +26,18 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
-// procFSType is the factory class for procfs.
+// Name is the default filesystem name.
+const Name = "proc"
+
+// FilesystemType is the factory class for procfs.
//
// +stateify savable
-type procFSType struct{}
+type FilesystemType struct{}
-var _ vfs.FilesystemType = (*procFSType)(nil)
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
// GetFilesystem implements vfs.FilesystemType.
-func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+func (ft *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
k := kernel.KernelFromContext(ctx)
if k == nil {
return nil, nil, fmt.Errorf("procfs requires a kernel")
@@ -47,12 +50,13 @@ func (ft *procFSType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFile
procfs := &kernfs.Filesystem{}
procfs.VFSFilesystem().Init(vfsObj, procfs)
- var data *InternalData
+ var cgroups map[string]string
if opts.InternalData != nil {
- data = opts.InternalData.(*InternalData)
+ data := opts.InternalData.(*InternalData)
+ cgroups = data.Cgroups
}
- _, dentry := newTasksInode(procfs, k, pidns, data.Cgroups)
+ _, dentry := newTasksInode(procfs, k, pidns, cgroups)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
index 353e37195..f3f4e49b4 100644
--- a/pkg/sentry/fsimpl/proc/subtasks.go
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -105,8 +105,8 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -114,9 +114,9 @@ func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallb
}
// Open implements kernfs.Inode.
-func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index eb5bc62c0..2d814668a 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -98,9 +98,9 @@ func (i *taskInode) Valid(ctx context.Context) bool {
}
// Open implements kernfs.Inode.
-func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *taskInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index 14bd334e8..10c08fa90 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
"meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
"net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
"uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
}
inode := &tasksInode{
@@ -151,8 +151,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -163,8 +163,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: offset + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -196,8 +196,8 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
Ino: i.inoGen.NextIno(),
NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
}
- if !cb.Handle(dirent) {
- return offset, nil
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
}
offset++
}
@@ -205,9 +205,9 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback
}
// Open implements kernfs.Inode.
-func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (i *tasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index 608fec017..d4e1812d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -39,7 +39,10 @@ import (
func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index c7ce74883..3d5dc463c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
"tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}),
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index 6fc3524db..c5d531fe0 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -90,8 +90,7 @@ func setup(t *testing.T) *testutil.System {
ctx := k.SupervisorContext()
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("procfs", &procFSType{}, &vfs.RegisterFilesystemTypeOptions{
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
fsOpts := vfs.GetFilesystemOptions{
@@ -102,11 +101,11 @@ func setup(t *testing.T) *testutil.System {
},
},
}
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "procfs", &fsOpts)
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", Name, &fsOpts)
if err != nil {
t.Fatalf("NewMountNamespace(): %v", err)
}
- return testutil.NewSystem(ctx, t, vfsObj, mntns)
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
}
func TestTasksEmpty(t *testing.T) {
@@ -131,7 +130,7 @@ func TestTasks(t *testing.T) {
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc)
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -213,7 +212,7 @@ func TestTasksOffset(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
for i := 0; i < 3; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc); err != nil {
+ if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil {
t.Fatalf("CreateTask(): %v", err)
}
}
@@ -337,7 +336,7 @@ func TestTask(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- _, err := testutil.CreateTask(s.Ctx, "name", tc)
+ _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -352,7 +351,7 @@ func TestProcSelf(t *testing.T) {
k := kernel.KernelFromContext(s.Ctx)
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, "name", tc)
+ task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
@@ -433,7 +432,7 @@ func TestTree(t *testing.T) {
var tasks []*kernel.Task
for i := 0; i < 5; i++ {
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc)
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatalf("CreateTask(): %v", err)
}
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
index 66c0d8bc8..a741e2bb6 100644
--- a/pkg/sentry/fsimpl/sys/BUILD
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"sys.go",
],
+ visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index e35d52d17..c36c4fa11 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -28,6 +28,9 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the default filesystem name.
+const Name = "sysfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -97,9 +100,9 @@ func (d *dir) SetStat(fs *vfs.Filesystem, opts vfs.SetStatOptions) error {
}
// Open implements kernfs.Inode.Open.
-func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+func (d *dir) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
fd := &kernfs.GenericDirectoryFD{}
- fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, flags)
+ fd.Init(rp.Mount(), vfsd, &d.OrderedChildren, &opts)
return fd.VFSFileDescription(), nil
}
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
index 8b1cf0bd0..4b3602d47 100644
--- a/pkg/sentry/fsimpl/sys/sys_test.go
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -34,16 +34,15 @@ func newTestSystem(t *testing.T) *testutil.System {
}
ctx := k.SupervisorContext()
creds := auth.CredentialsFromContext(ctx)
- v := vfs.New()
- v.MustRegisterFilesystemType("sysfs", sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
- mns, err := v.NewMountNamespace(ctx, creds, "", "sysfs", &vfs.GetFilesystemOptions{})
+ mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("Failed to create new mount namespace: %v", err)
}
- return testutil.NewSystem(ctx, t, v, mns)
+ return testutil.NewSystem(ctx, t, k.VFS(), mns)
}
func TestReadCPUFile(t *testing.T) {
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
index efd5974c4..e4f36f4ae 100644
--- a/pkg/sentry/fsimpl/testutil/BUILD
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -16,7 +16,7 @@ go_library(
"//pkg/cpuid",
"//pkg/fspath",
"//pkg/memutil",
- "//pkg/sentry/fs",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/sched",
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index 89f8c4915..488478e29 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/memutil"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
// Platforms are plugable.
_ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
@@ -99,36 +100,41 @@ func Boot() (*kernel.Kernel, error) {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
- ctx := k.SupervisorContext()
+ kernel.VFS2Enabled = true
- // Create mount namespace without root as it's the minimum required to create
- // the global thread group.
- mntns, err := fs.NewMountNamespace(ctx, nil)
- if err != nil {
- return nil, err
+ if err := k.VFS().Init(); err != nil {
+ return nil, fmt.Errorf("VFS init: %v", err)
}
+ k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
ls, err := limits.NewLinuxLimitSet()
if err != nil {
return nil, err
}
- tg := k.NewThreadGroup(mntns, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
k.TestOnly_SetGlobalInit(tg)
return k, nil
}
// CreateTask creates a new bare bones task for tests.
-func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup) (*kernel.Task, error) {
+func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
k := kernel.KernelFromContext(ctx)
config := &kernel.TaskConfig{
Kernel: k,
ThreadGroup: tc,
TaskContext: &kernel.TaskContext{Name: name},
Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ MountNamespaceVFS2: mntns,
+ FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
}
return k.TaskSet().NewTask(config)
}
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
index 1c98335c1..e16808c63 100644
--- a/pkg/sentry/fsimpl/testutil/testutil.go
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -41,12 +41,12 @@ type System struct {
Creds *auth.Credentials
VFS *vfs.VirtualFilesystem
Root vfs.VirtualDentry
- mns *vfs.MountNamespace
+ MntNs *vfs.MountNamespace
}
// NewSystem constructs a System.
//
-// Precondition: Caller must hold a reference on mns, whose ownership
+// Precondition: Caller must hold a reference on MntNs, whose ownership
// is transferred to the new System.
func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
s := &System{
@@ -54,7 +54,7 @@ func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns
Ctx: ctx,
Creds: auth.CredentialsFromContext(ctx),
VFS: v,
- mns: mns,
+ MntNs: mns,
Root: mns.Root(),
}
return s
@@ -75,7 +75,7 @@ func (s *System) WithSubtest(t *testing.T) *System {
Ctx: s.Ctx,
Creds: s.Creds,
VFS: s.VFS,
- mns: s.mns,
+ MntNs: s.MntNs,
Root: s.Root,
}
}
@@ -90,7 +90,7 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System {
Ctx: ctx,
Creds: s.Creds,
VFS: s.VFS,
- mns: s.mns,
+ MntNs: s.MntNs,
Root: s.Root,
}
}
@@ -98,7 +98,7 @@ func (s *System) WithTemporaryContext(ctx context.Context) *System {
// Destroy release resources associated with a test system.
func (s *System) Destroy() {
s.Root.DecRef()
- s.mns.DecRef(s.VFS) // Reference on mns passed to NewSystem.
+ s.MntNs.DecRef() // Reference on MntNs passed to NewSystem.
}
// ReadToEnd reads the contents of fd until EOF to a string.
@@ -226,7 +226,7 @@ func (d *DirentCollector) SkipDotsChecks(value bool) {
}
// Handle implements vfs.IterDirentsCallback.Handle.
-func (d *DirentCollector) Handle(dirent vfs.Dirent) bool {
+func (d *DirentCollector) Handle(dirent vfs.Dirent) error {
d.mu.Lock()
if d.dirents == nil {
d.dirents = make(map[string]*vfs.Dirent)
@@ -234,7 +234,7 @@ func (d *DirentCollector) Handle(dirent vfs.Dirent) bool {
d.order = append(d.order, &dirent)
d.dirents[dirent.Name] = &dirent
d.mu.Unlock()
- return true
+ return nil
}
// Count returns the number of dirents currently in the collector.
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index c61366224..57abd5583 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -38,6 +38,7 @@ go_library(
"//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
@@ -47,6 +48,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/lock",
"//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
@@ -86,6 +88,7 @@ go_test(
"//pkg/context",
"//pkg/fspath",
"//pkg/sentry/contexttest",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/contexttest",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index 54241c8e8..383133e44 100644
--- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -175,7 +175,10 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -183,7 +186,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef(vfsObj)
+ defer mntns.DecRef()
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
@@ -366,7 +369,10 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -374,7 +380,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
- defer mntns.DecRef(vfsObj)
+ defer mntns.DecRef()
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index dc0d27cf9..b4380af38 100644
--- a/pkg/sentry/fsimpl/tmpfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -74,25 +74,25 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
defer fs.mu.Unlock()
if fd.off == 0 {
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
Ino: vfsd.Impl().(*dentry).inode.ino,
NextOff: 1,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: "..",
Type: parentInode.direntType(),
Ino: parentInode.ino,
NextOff: 2,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
@@ -111,14 +111,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
for child != nil {
// Skip other directoryFD iterators.
if child.inode != nil {
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: child.vfsd.Name(),
Type: child.inode.direntType(),
Ino: child.inode.ino,
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 5ee9cf1e9..e1b551422 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,7 +16,6 @@ package tmpfs
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -334,7 +333,7 @@ afterTrailingSymlink:
}
func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(opts.Flags)
+ ats := vfs.AccessTypesForOpenFlags(opts)
if !afterCreate {
if err := d.inode.checkPermissions(rp.Credentials(), ats, d.inode.isDir()); err != nil {
return nil, err
@@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
if opts.Flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data.Truncate(0, impl.memFile)
- atomic.StoreUint64(&impl.size, 0)
- impl.mu.Unlock()
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
}
return &fd.vfsfd, nil
case *directory:
@@ -486,7 +484,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
vfsObj := rp.VirtualFilesystem()
oldParentDir := oldParent.inode.impl.(*directory)
newParentDir := newParent.inode.impl.(*directory)
- if err := vfsObj.PrepareRenameDentry(vfs.MountNamespaceFromContext(ctx), renamedVFSD, replacedVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareRenameDentry(mntns, renamedVFSD, replacedVFSD); err != nil {
return err
}
if replaced != nil {
@@ -543,7 +543,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
- if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
return err
}
parent.inode.impl.(*directory).childList.Remove(child)
@@ -622,7 +624,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
if child.inode.isDir() {
return syserror.EISDIR
}
- if !rp.MustBeDir() {
+ if rp.MustBeDir() {
return syserror.ENOTDIR
}
mnt := rp.Mount()
@@ -631,7 +633,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
}
defer mnt.EndWrite()
vfsObj := rp.VirtualFilesystem()
- if err := vfsObj.PrepareDeleteDentry(vfs.MountNamespaceFromContext(ctx), childVFSD); err != nil {
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef()
+ if err := vfsObj.PrepareDeleteDentry(mntns, childVFSD); err != nil {
return err
}
parent.inode.impl.(*directory).childList.Remove(child)
diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 5ee7f2a72..1614f2c39 100644
--- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -151,7 +151,10 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index e9e6faf67..711442424 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "fmt"
"io"
"math"
"sync/atomic"
@@ -22,7 +23,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -33,25 +36,53 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// regularFile is a regular (=S_IFREG) tmpfs file.
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
- // mu protects the fields below.
- mu sync.RWMutex
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
// data maps offsets into the file to offsets into memFile that store
// the file's data.
+ //
+ // Protected by dataMu.
data fsutil.FileRangeSet
- // size is the size of data, but accessed using atomic memory
- // operations to avoid locking in inode.stat().
- size uint64
-
// seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
}
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
@@ -65,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
-func (rf *regularFile) truncate(size uint64) (bool, error) {
- rf.mu.Lock()
- defer rf.mu.Unlock()
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
- if size == rf.size {
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
// Nothing to do.
return false, nil
}
- if size > rf.size {
- // Growing the file.
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
if rf.seals&linux.F_SEAL_GROW != 0 {
- // Seal does not allow growth.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- rf.size = size
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
return true, nil
}
- // Shrinking the file
+ // We are shrinking the file. First check if this is allowed.
if rf.seals&linux.F_SEAL_SHRINK != 0 {
- // Seal does not allow shrink.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
- // mappings.
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
- rf.data.Truncate(size, rf.memFile)
- rf.size = size
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
return true, nil
}
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
type regularFileFD struct {
fileDescription
@@ -151,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Overflow.
return 0, syserror.EFBIG
}
+ f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
}
@@ -192,6 +356,34 @@ func (fd *regularFileFD) Sync(ctx context.Context) error {
return nil
}
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *regularFileFD) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
+ return fd.inode().lockBSD(uid, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *regularFileFD) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
+ fd.inode().unlockBSD(uid)
+ return nil
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
+ return fd.inode().lockPOSIX(uid, t, rng, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
+ fd.inode().unlockPOSIX(uid, rng)
+ return nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
type regularFileReadWriter struct {
file *regularFile
@@ -221,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- rw.file.mu.RLock()
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
// Compute the range to read (limited by file size and overflow-checked).
- if rw.off >= rw.file.size {
- rw.file.mu.RUnlock()
+ if rw.off >= size {
return 0, io.EOF
}
- end := rw.file.size
+ end := size
if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
end = rend
}
@@ -242,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
// Get internal mappings.
ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -252,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -268,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -276,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
}
}
- rw.file.mu.RUnlock()
return done, nil
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- rw.file.mu.Lock()
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
// Compute the range to write (overflow-checked).
end := rw.off + srcs.NumBytes()
@@ -293,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Check if seals prevent either file growth or all writes.
switch {
case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
- rw.file.mu.Unlock()
return 0, syserror.EPERM
case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
// When growth is sealed, Linux effectively allows writes which would
@@ -315,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
}
if end <= rw.off {
// Truncation would result in no data being written.
- rw.file.mu.Unlock()
return 0, syserror.EPERM
}
}
@@ -372,9 +563,8 @@ exitLoop:
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.off > rw.file.size {
- atomic.StoreUint64(&rw.file.size, rw.off)
+ rw.file.size = rw.off
}
- rw.file.mu.Unlock()
return done, retErr
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 32552e261..0399725cf 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -24,9 +24,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -38,7 +40,11 @@ var nextFileID int64
func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
creds := auth.CredentialsFromContext(ctx)
- vfsObj := vfs.New()
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserMount: true,
})
@@ -49,7 +55,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr
root := mntns.Root()
return vfsObj, root, func() {
root.DecRef()
- mntns.DecRef(vfsObj)
+ mntns.DecRef()
}, nil
}
@@ -260,6 +266,60 @@ func TestPWrite(t *testing.T) {
}
}
+func TestLocks(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ var (
+ uid1 lock.UniqueID
+ uid2 lock.UniqueID
+ // Non-blocking.
+ block lock.Blocker
+ )
+
+ uid1 = 123
+ uid2 = 456
+
+ if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, block); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, block); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
+ t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, block); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+
+ rng1 := lock.LockRange{0, 1}
+ rng2 := lock.LockRange{1, 2}
+
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, rng1, block); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng2, block); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, rng1, block); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, rng1, block), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockPOSIX(ctx, uid1, rng1); err != nil {
+ t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err)
+ }
+}
+
func TestPRead(t *testing.T) {
ctx := contexttest.Context(t)
fd, cleanup, err := newFileFD(ctx, 0644)
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 88dbd6e35..521206305 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -18,9 +18,10 @@
// Lock order:
//
// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
// inode.mu
+// regularFileFD.offMu
+// regularFile.mapsMu
+// regularFile.dataMu
package tmpfs
import (
@@ -30,14 +31,19 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/lock"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the default filesystem name.
+const Name = "tmpfs"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -153,6 +159,9 @@ type inode struct {
rdevMajor uint32
rdevMinor uint32
+ // Advisory file locks, which lock at the inode level.
+ locks lock.FileLocks
+
impl interface{} // immutable
}
@@ -218,12 +227,15 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
if regFile, ok := i.impl.(*regularFile); ok {
- regFile.mu.Lock()
+ // Hold inode.mu and regFile.dataMu while mutating
+ // size.
+ i.mu.Lock()
+ regFile.dataMu.Lock()
regFile.data.DropAll(regFile.memFile)
atomic.StoreUint64(&regFile.size, 0)
- regFile.mu.Unlock()
+ regFile.dataMu.Unlock()
+ i.mu.Unlock()
}
} else if refs < 0 {
panic("tmpfs.inode.decRef() called without holding a reference")
@@ -312,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error {
if mask&linux.STATX_SIZE != 0 {
switch impl := i.impl.(type) {
case *regularFile:
- updated, err := impl.truncate(stat.Size)
+ updated, err := impl.truncateLocked(stat.Size)
if err != nil {
return err
}
@@ -352,6 +364,44 @@ func (i *inode) setStat(stat linux.Statx) error {
return nil
}
+// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
+func (i *inode) lockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ switch i.impl.(type) {
+ case *regularFile:
+ return i.locks.LockBSD(uid, t, block)
+ }
+ return syserror.EBADF
+}
+
+// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
+func (i *inode) unlockBSD(uid fslock.UniqueID) error {
+ switch i.impl.(type) {
+ case *regularFile:
+ i.locks.UnlockBSD(uid)
+ return nil
+ }
+ return syserror.EBADF
+}
+
+// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
+func (i *inode) lockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+ switch i.impl.(type) {
+ case *regularFile:
+ return i.locks.LockPOSIX(uid, t, rng, block)
+ }
+ return syserror.EBADF
+}
+
+// TODO(gvisor.dev/issue/1480): support file locking for file types other than regular.
+func (i *inode) unlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) error {
+ switch i.impl.(type) {
+ case *regularFile:
+ i.locks.UnlockPOSIX(uid, rng)
+ return nil
+ }
+ return syserror.EBADF
+}
+
// allocatedBlocksForSize returns the number of 512B blocks needed to
// accommodate the given size in bytes, as appropriate for struct
// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 334432abf..07bf39fed 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -10,6 +10,7 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
deps = [
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index a7dfb78a7..2916a0644 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -28,6 +28,10 @@ type Stack interface {
// interface indexes to a slice of associated interface address properties.
InterfaceAddrs() map[int32][]InterfaceAddr
+ // AddInterfaceAddr adds an address to the network interface identified by
+ // index.
+ AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+
// SupportsIPv6 returns true if the stack supports IPv6 connectivity.
SupportsIPv6() bool
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..c16667e7f
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ creator NetworkStackCreator
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index dcfcbd97e..d8961fc94 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -47,6 +47,12 @@ func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
return s.InterfaceAddrsMap
}
+// AddInterfaceAddr implements Stack.AddInterfaceAddr.
+func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr)
+ return nil
+}
+
// SupportsIPv6 implements Stack.SupportsIPv6.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index a27628c0a..beba29a09 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -91,6 +91,7 @@ go_library(
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_opts.go",
"kernel_state.go",
"pending_signals.go",
"pending_signals_list.go",
@@ -156,6 +157,7 @@ go_library(
"//pkg/context",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/metric",
"//pkg/refs",
@@ -166,6 +168,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
@@ -198,6 +201,7 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 23b88f7a6..58001d56c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.get(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
@@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
// Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.setAll(fd, file, fileVFS2, flags)
return nil
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 2448c1d99..47f78df9a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -37,10 +38,16 @@ type FSContext struct {
// destroyed.
root *fs.Dirent
+ // rootVFS2 is the filesystem root.
+ rootVFS2 vfs.VirtualDentry
+
// cwd is the current working directory. Will be nil iff the FSContext
// has been destroyed.
cwd *fs.Dirent
+ // cwdVFS2 is the current working directory.
+ cwdVFS2 vfs.VirtualDentry
+
// umask is the current file mode creation mask. When a thread using this
// context invokes a syscall that creates a file, bits set in umask are
// removed from the permissions that the file is created with.
@@ -60,6 +67,19 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
return &f
}
+// NewFSContextVFS2 returns a new filesystem context.
+func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ f := FSContext{
+ rootVFS2: root,
+ cwdVFS2: cwd,
+ umask: umask,
+ }
+ f.EnableLeakCheck("kernel.FSContext")
+ return &f
+}
+
// destroy is the destructor for an FSContext.
//
// This will call DecRef on both root and cwd Dirents. If either call to
@@ -75,11 +95,17 @@ func (f *FSContext) destroy() {
f.mu.Lock()
defer f.mu.Unlock()
- f.root.DecRef()
- f.root = nil
-
- f.cwd.DecRef()
- f.cwd = nil
+ if VFS2Enabled {
+ f.rootVFS2.DecRef()
+ f.rootVFS2 = vfs.VirtualDentry{}
+ f.cwdVFS2.DecRef()
+ f.cwdVFS2 = vfs.VirtualDentry{}
+ } else {
+ f.root.DecRef()
+ f.root = nil
+ f.cwd.DecRef()
+ f.cwd = nil
+ }
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
@@ -93,12 +119,21 @@ func (f *FSContext) DecRef() {
func (f *FSContext) Fork() *FSContext {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
- f.root.IncRef()
+
+ if VFS2Enabled {
+ f.cwdVFS2.IncRef()
+ f.rootVFS2.IncRef()
+ } else {
+ f.cwd.IncRef()
+ f.root.IncRef()
+ }
+
return &FSContext{
- cwd: f.cwd,
- root: f.root,
- umask: f.umask,
+ cwd: f.cwd,
+ root: f.root,
+ cwdVFS2: f.cwdVFS2,
+ rootVFS2: f.rootVFS2,
+ umask: f.umask,
}
}
@@ -109,12 +144,23 @@ func (f *FSContext) Fork() *FSContext {
func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- if f.cwd != nil {
- f.cwd.IncRef()
- }
+
+ f.cwd.IncRef()
return f.cwd
}
+// WorkingDirectoryVFS2 returns the current working directory.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cwdVFS2.IncRef()
+ return f.cwdVFS2
+}
+
// SetWorkingDirectory sets the current working directory.
// This will take an extra reference on the Dirent.
//
@@ -137,6 +183,20 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetWorkingDirectoryVFS2 sets the current working directory.
+// This will take an extra reference on the VirtualDentry.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectoryVFS2(d vfs.VirtualDentry) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ old := f.cwdVFS2
+ f.cwdVFS2 = d
+ d.IncRef()
+ old.DecRef()
+}
+
// RootDirectory returns the current filesystem root.
//
// This will return nil if called after destroy(), otherwise it will return a
@@ -150,6 +210,18 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
return f.root
}
+// RootDirectoryVFS2 returns the current filesystem root.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.rootVFS2.IncRef()
+ return f.rootVFS2
+}
+
// SetRootDirectory sets the root directory.
// This will take an extra reference on the Dirent.
//
@@ -172,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
// Umask returns the current umask.
func (f *FSContext) Umask() uint {
f.mu.Lock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index dcd6e91c4..8b76750e9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -43,11 +43,13 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -71,6 +73,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
+// easy access everywhere. To be removed once VFS2 becomes the default.
+var VFS2Enabled = false
+
// Kernel represents an emulated Linux kernel. It must be initialized by calling
// Init() or LoadFrom().
//
@@ -105,7 +111,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -235,6 +241,16 @@ type Kernel struct {
// events. This is initialized lazily on the first unimplemented
// syscall.
unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
+
+ // SpecialOpts contains special kernel options.
+ SpecialOpts
+
+ // VFS keeps the filesystem state used across the kernel.
+ vfs vfs.VirtualFilesystem
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
}
// InitKernelArgs holds arguments to Init.
@@ -248,8 +264,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -308,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -531,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() {
func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -563,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
log.Infof("Kernel load stats: %s", &stats)
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
@@ -621,7 +643,7 @@ type CreateProcessArgs struct {
// File is a passed host FD pointing to a file to load as the init binary.
//
// This is checked if and only if Filename is "".
- File *fs.File
+ File fsbridge.File
// Argvv is a list of arguments.
Argv []string
@@ -670,6 +692,13 @@ type CreateProcessArgs struct {
// increment it).
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 optionally contains the mount namespace for this
+ // process. If nil, the init process's mount namespace is used.
+ //
+ // Anyone setting MountNamespaceVFS2 must donate a reference (i.e.
+ // increment it).
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// ContainerID is the container that the process belongs to.
ContainerID string
}
@@ -708,11 +737,22 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.args.Credentials
case fs.CtxRoot:
if ctx.args.MountNamespace != nil {
- // MountNamespace.Root() will take a reference on the root
- // dirent for us.
+ // MountNamespace.Root() will take a reference on the root dirent for us.
return ctx.args.MountNamespace.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.args.MountNamespaceVFS2 == nil {
+ return nil
+ }
+ // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
+ return ctx.args.MountNamespaceVFS2.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2 takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case ktime.CtxRealtimeClock:
@@ -754,34 +794,77 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
defer k.extMu.Unlock()
log.Infof("EXEC: %v", args.Argv)
- // Grab the mount namespace.
- mounts := args.MountNamespace
- if mounts == nil {
- mounts = k.GlobalInit().Leader().MountNamespace()
- mounts.IncRef()
- }
-
- tg := k.NewThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
ctx := args.NewContext(k)
- // Get the root directory from the MountNamespace.
- root := mounts.Root()
- // The call to newFSContext below will take a reference on root, so we
- // don't need to hold this one.
- defer root.DecRef()
-
- // Grab the working directory.
- remainingTraversals := uint(args.MaxSymlinkTraversals)
- wd := root // Default.
- if args.WorkingDirectory != "" {
- var err error
- wd, err = mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
- if err != nil {
- return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ var (
+ opener fsbridge.Lookup
+ fsContext *FSContext
+ mntns *fs.MountNamespace
+ )
+
+ if VFS2Enabled {
+ mntnsVFS2 := args.MountNamespaceVFS2
+ if mntnsVFS2 == nil {
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ mntnsVFS2 = k.GlobalInit().Leader().MountNamespaceVFS2()
}
- defer wd.DecRef()
+ // Get the root directory from the MountNamespace.
+ root := args.MountNamespaceVFS2.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: wd,
+ Path: fspath.Parse(args.WorkingDirectory),
+ FollowFinalSymlink: true,
+ }
+ var err error
+ wd, err = k.VFS().GetDentryAt(ctx, args.Credentials, &pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd)
+ fsContext = NewFSContextVFS2(root, wd, args.Umask)
+
+ } else {
+ mntns = args.MountNamespace
+ if mntns == nil {
+ mntns = k.GlobalInit().Leader().MountNamespace()
+ mntns.IncRef()
+ }
+ // Get the root directory from the MountNamespace.
+ root := mntns.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef()
+
+ // Grab the working directory.
+ remainingTraversals := args.MaxSymlinkTraversals
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ var err error
+ wd, err = mntns.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef()
+ }
+ opener = fsbridge.NewFSLookup(mntns, root, wd)
+ fsContext = newFSContext(root, wd, args.Umask)
}
+ tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
+
// Check which file to start from.
switch {
case args.Filename != "":
@@ -802,11 +885,9 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
}
// Create a fresh task context.
- remainingTraversals = uint(args.MaxSymlinkTraversals)
+ remainingTraversals := args.MaxSymlinkTraversals
loadArgs := loader.LoadArgs{
- Mounts: mounts,
- Root: root,
- WorkingDirectory: wd,
+ Opener: opener,
RemainingTraversals: &remainingTraversals,
ResolveFinal: true,
Filename: args.Filename,
@@ -831,13 +912,15 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
Kernel: k,
ThreadGroup: tg,
TaskContext: tc,
- FSContext: newFSContext(root, wd, args.Umask),
+ FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
AbstractSocketNamespace: args.AbstractSocketNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
ContainerID: args.ContainerID,
}
t, err := k.tasks.NewTask(config)
@@ -1097,6 +1180,14 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
k.sendExternalSignal(info, context)
}
+// SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup.
+// This function doesn't skip signals like SendExternalSignal does.
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return tg.SendSignal(info)
+}
+
// SendContainerSignal sends the given signal to all processes inside the
// namespace that match the given container ID.
func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
@@ -1175,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
@@ -1375,6 +1465,20 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.globalInit.mounts.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.k.globalInit == nil {
+ return vfs.VirtualDentry{}
+ }
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ defer mntns.DecRef()
+ // Root() takes a reference on the root dirent for us.
+ return mntns.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2() takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
case ktime.CtxRealtimeClock:
@@ -1420,3 +1524,8 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
Registers: t.Arch().StateData().Proto(),
})
}
+
+// VFS returns the virtual filesystem for the kernel.
+func (k *Kernel) VFS() *vfs.VirtualFilesystem {
+ return &k.vfs
+}
diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go
new file mode 100644
index 000000000..2e66ec587
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_opts.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// SpecialOpts contains non-standard options for the kernel.
+//
+// +stateify savable
+type SpecialOpts struct{}
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index efebfd872..ded95f532 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -303,26 +303,14 @@ func (t *Task) rseqAddrInterrupt() {
return
}
- buf = t.CopyScratchBuffer(linux.SizeOfRSeqCriticalSection)
- if _, err := t.CopyInBytes(critAddr, buf); err != nil {
+ var cs linux.RSeqCriticalSection
+ if err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
return
}
- // Manually marshal RSeqCriticalSection as this is in the hot path when
- // rseq is enabled. It must be as fast as possible.
- //
- // TODO(b/130243041): Replace with go_marshal.
- cs := linux.RSeqCriticalSection{
- Version: usermem.ByteOrder.Uint32(buf[0:4]),
- Flags: usermem.ByteOrder.Uint32(buf[4:8]),
- Start: usermem.ByteOrder.Uint64(buf[8:16]),
- PostCommitOffset: usermem.ByteOrder.Uint64(buf[16:24]),
- Abort: usermem.ByteOrder.Uint64(buf[24:32]),
- }
-
if cs.Version != 0 {
t.Debugf("Unknown version in %+v", cs)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 981e8c7fe..2cee2e6ed 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -424,6 +424,11 @@ type Task struct {
// abstractSockets is protected by mu.
abstractSockets *AbstractSocketNamespace
+ // mountNamespaceVFS2 is the task's mount namespace.
+ //
+ // It is protected by mu. It is owned by the task goroutine.
+ mountNamespaceVFS2 *vfs.MountNamespace
+
// parentDeathSignal is sent to this task's thread group when its parent exits.
//
// parentDeathSignal is protected by mu.
@@ -481,13 +486,10 @@ type Task struct {
numaPolicy int32
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -638,6 +640,11 @@ func (t *Task) Value(key interface{}) interface{} {
return int32(t.ThreadGroup().ID())
case fs.CtxRoot:
return t.fsContext.RootDirectory()
+ case vfs.CtxRoot:
+ return t.fsContext.RootDirectoryVFS2()
+ case vfs.CtxMountNamespace:
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
case fs.CtxDirentCacheLimiter:
return t.k.DirentCacheLimiter
case inet.CtxStack:
@@ -701,6 +708,14 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
// Preconditions: The caller must be running on the task goroutine, or t.mu
// must be locked.
func (t *Task) IsChrooted() bool {
+ if VFS2Enabled {
+ realRoot := t.mountNamespaceVFS2.Root()
+ defer realRoot.DecRef()
+ root := t.fsContext.RootDirectoryVFS2()
+ defer root.DecRef()
+ return root != realRoot
+ }
+
realRoot := t.tg.mounts.Root()
defer realRoot.DecRef()
root := t.fsContext.RootDirectory()
@@ -774,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -783,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
@@ -796,6 +829,15 @@ func (t *Task) MountNamespace() *fs.MountNamespace {
return t.tg.mounts
}
+// MountNamespaceVFS2 returns t's MountNamespace. A reference is taken on the
+// returned mount namespace.
+func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
+}
+
// AbstractSockets returns t's AbstractSocketNamespace.
func (t *Task) AbstractSockets() *AbstractSocketNamespace {
return t.abstractSockets
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 53d4d211b..78866f280 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -54,8 +55,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -199,6 +199,17 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
+ // TODO(b/63601033): Implement CLONE_NEWNS.
+ mntnsVFS2 := t.mountNamespaceVFS2
+ if mntnsVFS2 != nil {
+ mntnsVFS2.IncRef()
+ }
+
tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
if err != nil {
return 0, nil, err
@@ -241,7 +252,9 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
rseqAddr := usermem.Addr(0)
rseqSignature := uint32(0)
if opts.NewThreadGroup {
- tg.mounts.IncRef()
+ if tg.mounts != nil {
+ tg.mounts.IncRef()
+ }
sh := t.tg.signalHandlers
if opts.NewSignalHandlers {
sh = sh.Fork()
@@ -260,11 +273,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ MountNamespaceVFS2: mntnsVFS2,
RSeqAddr: rseqAddr,
RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
@@ -274,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
@@ -473,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 2d6e7733c..0158b1788 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -136,11 +136,11 @@ func (t *Task) Stack() *arch.Stack {
func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) {
// If File is not nil, we should load that instead of resolving Filename.
if args.File != nil {
- args.Filename = args.File.MappedName(ctx)
+ args.Filename = args.File.PathnameWithDeleted(ctx)
}
// Prepare a new user address space to load into.
- m := mm.NewMemoryManager(k, k)
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
defer m.DecUsers(ctx)
args.MemoryManager = m
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 8f57a34a6..00c425cca 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.mu.Unlock()
t.unstopVforkParent()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
- t.MemoryManager().Activate()
+ t.MemoryManager().Activate(t)
t.ptraceExec(oldTID)
return (*runSyscallExit)(nil)
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 435761e5a..c4ade6e8e 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -269,6 +269,13 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.fsContext.DecRef()
t.fdTable.DecRef()
+ t.mu.Lock()
+ if t.mountNamespaceVFS2 != nil {
+ t.mountNamespaceVFS2.DecRef()
+ t.mountNamespaceVFS2 = nil
+ }
+ t.mu.Unlock()
+
// If this is the last task to exit from the thread group, release the
// thread group's resources.
if lastExiter {
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 41259210c..eeccaa197 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -32,21 +32,21 @@ const (
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
@@ -198,18 +198,11 @@ func (t *Task) traceExecEvent(tc *TaskContext) {
if !trace.IsEnabled() {
return
}
- d := tc.MemoryManager.Executable()
- if d == nil {
+ file := tc.MemoryManager.Executable()
+ if file == nil {
trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
return
}
- defer d.DecRef()
- root := t.fsContext.RootDirectory()
- if root == nil {
- trace.Logf(t.traceContext, traceCategory, "exec: << no root directory >>")
- return
- }
- defer root.DecRef()
- n, _ := d.FullName(root)
- trace.Logf(t.traceContext, traceCategory, "exec: %s", n)
+ defer file.DecRef()
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index de838beef..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,10 +17,12 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -64,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -80,6 +81,9 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // MountNamespaceVFS2 is the MountNamespace of the new task.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// RSeqAddr is a pointer to the the userspace linux.RSeq structure.
RSeqAddr usermem.Addr
@@ -116,28 +120,29 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
parent: cfg.Parent,
children: make(map[*Task]struct{}),
},
- runState: (*runApp)(nil),
- interruptChan: make(chan struct{}, 1),
- signalMask: cfg.SignalMask,
- signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
- tc: *tc,
- fsContext: cfg.FSContext,
- fdTable: cfg.FDTable,
- p: cfg.Kernel.Platform.NewContext(),
- k: cfg.Kernel,
- ptraceTracees: make(map[*Task]struct{}),
- allowedCPUMask: cfg.AllowedCPUMask.Copy(),
- ioUsage: &usage.IO{},
- niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
- utsns: cfg.UTSNamespace,
- ipcns: cfg.IPCNamespace,
- abstractSockets: cfg.AbstractSocketNamespace,
- rseqCPU: -1,
- rseqAddr: cfg.RSeqAddr,
- rseqSignature: cfg.RSeqSignature,
- futexWaiter: futex.NewWaiter(),
- containerID: cfg.ContainerID,
+ runState: (*runApp)(nil),
+ interruptChan: make(chan struct{}, 1),
+ signalMask: cfg.SignalMask,
+ signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ tc: *tc,
+ fsContext: cfg.FSContext,
+ fdTable: cfg.FDTable,
+ p: cfg.Kernel.Platform.NewContext(),
+ k: cfg.Kernel,
+ ptraceTracees: make(map[*Task]struct{}),
+ allowedCPUMask: cfg.AllowedCPUMask.Copy(),
+ ioUsage: &usage.IO{},
+ niceness: cfg.Niceness,
+ netns: cfg.NetworkNamespace,
+ utsns: cfg.UTSNamespace,
+ ipcns: cfg.IPCNamespace,
+ abstractSockets: cfg.AbstractSocketNamespace,
+ mountNamespaceVFS2: cfg.MountNamespaceVFS2,
+ rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
+ futexWaiter: futex.NewWaiter(),
+ containerID: cfg.ContainerID,
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 2bf3ce8a8..b02044ad2 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
if mm := t.MemoryManager(); mm != nil {
- if err := mm.Activate(); err != nil {
+ if err := mm.Activate(t); err != nil {
panic("unable to activate mm: " + err.Error())
}
}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 768e958d2..268f62e9d 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -256,7 +256,7 @@ type ThreadGroup struct {
tty *TTY
}
-// NewThreadGroup returns a new, empty thread group in PID namespace ns. The
+// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
// thread group leader will send its parent terminationSignal when it exits.
// The new thread group isn't visible to the system until a task has been
// created inside of it by a successful call to TaskSet.NewTask.
@@ -317,7 +317,9 @@ func (tg *ThreadGroup) release() {
for _, it := range its {
it.DestroyTimer()
}
- tg.mounts.DecRef()
+ if tg.mounts != nil {
+ tg.mounts.DecRef()
+ }
}
// forEachChildThreadGroupLocked indicates over all child ThreadGroups.
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 23790378a..c6aa65f28 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
@@ -40,6 +41,7 @@ go_library(
"//pkg/sentry/pgalloc",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/usermem",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 122ed05c2..616fafa2c 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -97,11 +97,11 @@ type elfInfo struct {
// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
//
// ctx may be nil if f does not need it.
-func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
+func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) {
// Check ident first; it will tell us the endianness of the rest of the
// structs.
var ident [elf.EI_NIDENT]byte
- _, err := readFull(ctx, f, usermem.BytesIOSequence(ident[:]), 0)
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(ident[:]), 0)
if err != nil {
log.Infof("Error reading ELF ident: %v", err)
// The entire ident array always exists.
@@ -137,7 +137,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
var hdr elf.Header64
hdrBuf := make([]byte, header64Size)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(hdrBuf), 0)
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
if err != nil {
log.Infof("Error reading ELF header: %v", err)
// The entire header always exists.
@@ -187,7 +187,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
}
phdrBuf := make([]byte, totalPhdrSize)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
if err != nil {
log.Infof("Error reading ELF phdrs: %v", err)
// If phdrs were specified, they should all exist.
@@ -227,7 +227,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
// mapSegment maps a phdr into the Task. offset is the offset to apply to
// phdr.Vaddr.
-func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
// We must make a page-aligned mapping.
adjust := usermem.Addr(phdr.Vaddr).PageOffset()
@@ -395,7 +395,7 @@ type loadedELF struct {
//
// Preconditions:
// * f is an ELF file
-func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
first := true
var start, end usermem.Addr
var interpreter string
@@ -431,7 +431,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
}
path := make([]byte, phdr.Filesz)
- _, err := readFull(ctx, f, usermem.BytesIOSequence(path), int64(phdr.Off))
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off))
if err != nil {
// If an interpreter was specified, it should exist.
ctx.Infof("Error reading PT_INTERP path: %v", err)
@@ -564,7 +564,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// Preconditions:
// * f is an ELF file
// * f is the first ELF loaded into m
-func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
+func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) {
info, err := parseHeader(ctx, f)
if err != nil {
ctx.Infof("Failed to parse initial ELF: %v", err)
@@ -602,7 +602,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS
//
// Preconditions:
// * f is an ELF file
-func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, initial loadedELF) (loadedELF, error) {
+func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) {
info, err := parseHeader(ctx, f)
if err != nil {
if err == syserror.ENOEXEC {
@@ -649,16 +649,14 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
// Refresh the traversal limit.
*args.RemainingTraversals = linux.MaxSymlinkTraversals
args.Filename = bin.interpreter
- d, i, err := openPath(ctx, args)
+ intFile, err := openPath(ctx, args)
if err != nil {
ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
return loadedELF{}, nil, err
}
- defer i.DecRef()
- // We don't need the Dirent.
- d.DecRef()
+ defer intFile.DecRef()
- interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin)
+ interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin)
if err != nil {
ctx.Infof("Error loading interpreter: %v", err)
return loadedELF{}, nil, err
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index 098a45d36..3886b4d33 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -19,7 +19,7 @@ import (
"io"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,9 +37,9 @@ const (
)
// parseInterpreterScript returns the interpreter path and argv.
-func parseInterpreterScript(ctx context.Context, filename string, f *fs.File, argv []string) (newpath string, newargv []string, err error) {
+func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.File, argv []string) (newpath string, newargv []string, err error) {
line := make([]byte, interpMaxLineLength)
- n, err := readFull(ctx, f, usermem.BytesIOSequence(line), 0)
+ n, err := f.ReadFull(ctx, usermem.BytesIOSequence(line), 0)
// Short read is OK.
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 9a613d6b7..d6675b8f0 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -20,7 +20,6 @@ import (
"fmt"
"io"
"path"
- "strings"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -29,8 +28,10 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -41,16 +42,6 @@ type LoadArgs struct {
// MemoryManager is the memory manager to load the executable into.
MemoryManager *mm.MemoryManager
- // Mounts is the mount namespace in which to look up Filename.
- Mounts *fs.MountNamespace
-
- // Root is the root directory under which to look up Filename.
- Root *fs.Dirent
-
- // WorkingDirectory is the working directory under which to look up
- // Filename.
- WorkingDirectory *fs.Dirent
-
// RemainingTraversals is the maximum number of symlinks to follow to
// resolve Filename. This counter is passed by reference to keep it
// updated throughout the call stack.
@@ -65,7 +56,12 @@ type LoadArgs struct {
// File is an open fs.File object of the executable. If File is not
// nil, then File will be loaded and Filename will be ignored.
- File *fs.File
+ //
+ // The caller is responsible for checking that the user can execute this file.
+ File fsbridge.File
+
+ // Opener is used to open the executable file when 'File' is nil.
+ Opener fsbridge.Lookup
// CloseOnExec indicates that the executable (or one of its parent
// directories) was opened with O_CLOEXEC. If the executable is an
@@ -106,103 +102,32 @@ func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset in
// installed in the Task FDTable. The caller takes ownership of both.
//
// args.Filename must be a readable, executable, regular file.
-func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) {
+func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) {
if args.Filename == "" {
ctx.Infof("cannot open empty name")
- return nil, nil, syserror.ENOENT
- }
-
- var d *fs.Dirent
- var err error
- if args.ResolveFinal {
- d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- } else {
- d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- }
- if err != nil {
- return nil, nil, err
- }
- // Defer a DecRef for the sake of failure cases.
- defer d.DecRef()
-
- if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
- return nil, nil, syserror.ELOOP
- }
-
- if err := checkPermission(ctx, d); err != nil {
- return nil, nil, err
- }
-
- // If they claim it's a directory, then make sure.
- //
- // N.B. we reject directories below, but we must first reject
- // non-directories passed as directories.
- if strings.HasSuffix(args.Filename, "/") && !fs.IsDir(d.Inode.StableAttr) {
- return nil, nil, syserror.ENOTDIR
- }
-
- if err := checkIsRegularFile(ctx, d, args.Filename); err != nil {
- return nil, nil, err
- }
-
- f, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, nil, err
- }
- // Defer a DecRef for the sake of failure cases.
- defer f.DecRef()
-
- if err := checkPread(ctx, f, args.Filename); err != nil {
- return nil, nil, err
- }
-
- d.IncRef()
- f.IncRef()
- return d, f, err
-}
-
-// checkFile performs checks on a file to be executed.
-func checkFile(ctx context.Context, f *fs.File, filename string) error {
- if err := checkPermission(ctx, f.Dirent); err != nil {
- return err
- }
-
- if err := checkIsRegularFile(ctx, f.Dirent, filename); err != nil {
- return err
+ return nil, syserror.ENOENT
}
- return checkPread(ctx, f, filename)
-}
-
-// checkPermission checks whether the file is readable and executable.
-func checkPermission(ctx context.Context, d *fs.Dirent) error {
- perms := fs.PermMask{
- // TODO(gvisor.dev/issue/160): Linux requires only execute
- // permission, not read. However, our backing filesystems may
- // prevent us from reading the file without read permission.
- //
- // Additionally, a task with a non-readable executable has
- // additional constraints on access via ptrace and procfs.
- Read: true,
- Execute: true,
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ opts := vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
}
- return d.Inode.CheckPermission(ctx, perms)
+ return args.Opener.OpenPath(ctx, args.Filename, opts, args.RemainingTraversals, args.ResolveFinal)
}
// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc.
-func checkIsRegularFile(ctx context.Context, d *fs.Dirent, filename string) error {
- attr := d.Inode.StableAttr
- if !fs.IsRegular(attr) {
- ctx.Infof("%s is not regular: %v", filename, attr)
- return syserror.EACCES
+func checkIsRegularFile(ctx context.Context, file fsbridge.File, filename string) error {
+ t, err := file.Type(ctx)
+ if err != nil {
+ return err
}
- return nil
-}
-
-// checkPread checks whether we can read the file at arbitrary offsets.
-func checkPread(ctx context.Context, f *fs.File, filename string) error {
- if !f.Flags().Pread {
- ctx.Infof("%s cannot be read at an offset: %+v", filename, f.Flags())
+ if t != linux.ModeRegular {
+ ctx.Infof("%q is not a regular file: %v", filename, t)
return syserror.EACCES
}
return nil
@@ -224,8 +149,10 @@ const (
maxLoaderAttempts = 6
)
-// loadExecutable loads an executable that is pointed to by args.File. If nil,
-// the path args.Filename is resolved and loaded. If the executable is an
+// loadExecutable loads an executable that is pointed to by args.File. The
+// caller is responsible for checking that the user can execute this file.
+// If nil, the path args.Filename is resolved and loaded (check that the user
+// can execute this file is done here in this case). If the executable is an
// interpreter script rather than an ELF, the binary of the corresponding
// interpreter will be loaded.
//
@@ -234,37 +161,27 @@ const (
// * arch.Context matching the binary arch
// * fs.Dirent of the binary file
// * Possibly updated args.Argv
-func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, fsbridge.File, []string, error) {
for i := 0; i < maxLoaderAttempts; i++ {
- var (
- d *fs.Dirent
- err error
- )
if args.File == nil {
- d, args.File, err = openPath(ctx, args)
- // We will return d in the successful case, but defer a DecRef for the
- // sake of intermediate loops and failure cases.
- if d != nil {
- defer d.DecRef()
- }
- if args.File != nil {
- defer args.File.DecRef()
+ var err error
+ args.File, err = openPath(ctx, args)
+ if err != nil {
+ ctx.Infof("Error opening %s: %v", args.Filename, err)
+ return loadedELF{}, nil, nil, nil, err
}
+ // Ensure file is release in case the code loops or errors out.
+ defer args.File.DecRef()
} else {
- d = args.File.Dirent
- d.IncRef()
- defer d.DecRef()
- err = checkFile(ctx, args.File, args.Filename)
- }
- if err != nil {
- ctx.Infof("Error opening %s: %v", args.Filename, err)
- return loadedELF{}, nil, nil, nil, err
+ if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil {
+ return loadedELF{}, nil, nil, nil, err
+ }
}
// Check the header. Is this an ELF or interpreter script?
var hdr [4]uint8
// N.B. We assume that reading from a regular file cannot block.
- _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0)
+ _, err := args.File.ReadFull(ctx, usermem.BytesIOSequence(hdr[:]), 0)
// Allow unexpected EOF, as a valid executable could be only three bytes
// (e.g., #!a).
if err != nil && err != io.ErrUnexpectedEOF {
@@ -281,9 +198,10 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
ctx.Infof("Error loading ELF: %v", err)
return loadedELF{}, nil, nil, nil, err
}
- // An ELF is always terminal. Hold on to d.
- d.IncRef()
- return loaded, ac, d, args.Argv, err
+ // An ELF is always terminal. Hold on to file.
+ args.File.IncRef()
+ return loaded, ac, args.File, args.Argv, err
+
case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
if args.CloseOnExec {
return loadedELF{}, nil, nil, nil, syserror.ENOENT
@@ -295,6 +213,7 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
}
// Refresh the traversal limit for the interpreter.
*args.RemainingTraversals = linux.MaxSymlinkTraversals
+
default:
ctx.Infof("Unknown magic: %v", hdr)
return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
@@ -317,11 +236,11 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
// * Load is called on the Task goroutine.
func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
// Load the executable itself.
- loaded, ac, d, newArgv, err := loadExecutable(ctx, args)
+ loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
if err != nil {
return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
- defer d.DecRef()
+ defer file.DecRef()
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
@@ -390,7 +309,7 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvStart(sl.EnvvStart)
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
- m.SetExecutable(d)
+ m.SetExecutable(file)
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index 52f446ed7..161b28c2c 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -69,6 +70,8 @@ type byteReader struct {
var _ fs.FileOperations = (*byteReader)(nil)
// newByteReaderFile creates a fake file to read data from.
+//
+// TODO(gvisor.dev/issue/1623): Convert to VFS2.
func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
// Create a fake inode.
inode := fs.NewInode(
@@ -123,7 +126,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq
// * PT_LOAD segments don't extend beyond the end of the file.
//
// ctx may be nil if f does not need it.
-func validateVDSO(ctx context.Context, f *fs.File, size uint64) (elfInfo, error) {
+func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) {
info, err := parseHeader(ctx, f)
if err != nil {
log.Infof("Unable to parse VDSO header: %v", err)
@@ -221,7 +224,7 @@ type VDSO struct {
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
- vdsoFile := newByteReaderFile(ctx, vdsoBin)
+ vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin))
// First make sure the VDSO is valid. vdsoFile does not use ctx, so a
// nil context can be passed.
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index e5729ced5..73591dab7 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -105,8 +105,8 @@ go_library(
"//pkg/safecopy",
"//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/fs",
"//pkg/sentry/fs/proc/seqfile",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/futex",
"//pkg/sentry/kernel/shm",
diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md
index e1322e373..f4d43d927 100644
--- a/pkg/sentry/mm/README.md
+++ b/pkg/sentry/mm/README.md
@@ -274,7 +274,7 @@ In the sentry:
methods
[`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform].
-[memmap]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/memmap/memmap.go
-[mm]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/mm/mm.go
-[pgalloc]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/pgalloc/pgalloc.go
-[platform]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/platform/platform.go
+[memmap]: https://github.com/google/gvisor/blob/master/pkg/sentry/memmap/memmap.go
+[mm]: https://github.com/google/gvisor/blob/master/pkg/sentry/mm/mm.go
+[pgalloc]: https://github.com/google/gvisor/blob/master/pkg/sentry/pgalloc/pgalloc.go
+[platform]: https://github.com/google/gvisor/blob/master/pkg/sentry/platform/platform.go
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index e58a63deb..0332fc71c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,7 +18,7 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -39,11 +39,18 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
//
// When this MemoryManager is no longer needed by a task, it should call
// Deactivate to release the reference.
-func (mm *MemoryManager) Activate() error {
+func (mm *MemoryManager) Activate(ctx context.Context) error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
- if atomicbitops.IncUnlessZeroInt32(&mm.active) {
- return nil
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
}
for {
@@ -85,16 +92,20 @@ func (mm *MemoryManager) Activate() error {
if as == nil {
// AddressSpace is unavailable, we must wait.
//
- // activeMu must not be held while waiting, as the user
- // of the address space we are waiting on may attempt
- // to take activeMu.
- //
- // Don't call UninterruptibleSleepStart to register the
- // wait to allow the watchdog stuck task to trigger in
- // case a process is starved waiting for the address
- // space.
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+ // Mark this task sleeping while waiting for the address space to
+ // prevent the watchdog from reporting it as a stuck task.
+ ctx.UninterruptibleSleepStart(false)
+ }
<-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
continue
}
@@ -118,8 +129,15 @@ func (mm *MemoryManager) Activate() error {
func (mm *MemoryManager) Deactivate() {
// Fast path: this is not the last goroutine to deactivate the
// MemoryManager.
- if atomicbitops.DecUnlessOneInt32(&mm.active) {
- return
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 47b8fbf43..d8a5b9d29 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -18,7 +18,6 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -29,16 +28,17 @@ import (
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
-func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
return &MemoryManager{
- p: p,
- mfp: mfp,
- haveASIO: p.SupportsAddressSpaceIO(),
- privateRefs: &privateRefs{},
- users: 1,
- auxv: arch.Auxv{},
- dumpability: UserDumpable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
}
}
@@ -80,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- dumpability: mm.dumpability,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
}
// Copy vmas.
@@ -229,7 +230,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
// IncUsers increments mm's user count and returns true. If the user count is
// already 0, IncUsers does nothing and returns false.
func (mm *MemoryManager) IncUsers() bool {
- return atomicbitops.IncUnlessZeroInt32(&mm.users)
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
}
// DecUsers decrements mm's user count. If the user count reaches 0, all
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index f550acae0..6a49334f4 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -16,7 +16,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -132,7 +132,7 @@ func (mm *MemoryManager) SetAuxv(auxv arch.Auxv) {
//
// An additional reference will be taken in the case of a non-nil executable,
// which must be released by the caller.
-func (mm *MemoryManager) Executable() *fs.Dirent {
+func (mm *MemoryManager) Executable() fsbridge.File {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
@@ -147,15 +147,15 @@ func (mm *MemoryManager) Executable() *fs.Dirent {
// SetExecutable sets the executable.
//
// This takes a reference on d.
-func (mm *MemoryManager) SetExecutable(d *fs.Dirent) {
+func (mm *MemoryManager) SetExecutable(file fsbridge.File) {
mm.metadataMu.Lock()
// Grab a new reference.
- d.IncRef()
+ file.IncRef()
// Set the executable.
orig := mm.executable
- mm.executable = d
+ mm.executable = file
mm.metadataMu.Unlock()
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 09e582dd3..c2195ae11 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -37,7 +37,7 @@ package mm
import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -215,7 +215,7 @@ type MemoryManager struct {
// is not nil, it holds a reference on the Dirent.
//
// executable is protected by metadataMu.
- executable *fs.Dirent
+ executable fsbridge.File
// dumpability describes if and how this MemoryManager may be dumped to
// userspace.
@@ -226,6 +226,11 @@ type MemoryManager struct {
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report to be sleeping
+ // before trying to activate the address space. When set to true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index edacca741..fdc308542 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -31,7 +31,7 @@ import (
func testMemoryManager(ctx context.Context) *MemoryManager {
p := platform.FromContext(ctx)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- mm := NewMemoryManager(p, mfp)
+ mm := NewMemoryManager(p, mfp, false)
mm.layout = arch.MmapLayout{
MinAddr: p.MinUserAddress(),
MaxAddr: p.MaxUserAddress(),
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 35cd55fef..4b23f7803 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -81,12 +81,6 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
// Save the death message, which will be thrown.
c.dieState.message = msg
- // Reload all registers to have an accurate stack trace when we return
- // to host mode. This means that the stack should be unwound correctly.
- if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
- throw(msg)
- }
-
// Setup the trampoline.
dieArchSetup(c, context, &c.dieState.guestRegs)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index a63a6a071..99cac665d 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -31,6 +31,12 @@ import (
//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(c.dieState.message)
+ }
+
// If the vCPU is in user mode, we set the stack to the stored stack
// value in the vCPU itself. We don't want to unwind the user stack.
if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 8076c7529..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -329,10 +329,12 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
if c := m.vCPUs[tid]; c != nil {
@@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1c8384e6b..b531f2f85 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -29,30 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
type kvmVcpuInit struct {
target uint32
features [7]uint32
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
index 6b078cd1e..f6da41c27 100644
--- a/pkg/sentry/platform/ring0/aarch64.go
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -88,14 +88,14 @@ const (
El0Sync_undef
El0Sync_dbg
El0Sync_inv
- VirtualizationException
_NR_INTERRUPTS
)
// System call vectors.
const (
- Syscall Vector = El0Sync_svc
- PageFault Vector = El0Sync_da
+ Syscall Vector = El0Sync_svc
+ PageFault Vector = El0Sync_da
+ VirtualizationException Vector = El0Error
)
// VirtualAddressBits returns the number bits available for virtual addresses.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 679842288..d42eda37b 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -25,10 +25,14 @@
// not available for calls.
//
+// ERET returns using the ELR and SPSR for the current exception level.
#define ERET() \
WORD $0xd69f03e0
+// RSV_REG is a register that holds el1 information temporarily.
#define RSV_REG R18_PLATFORM
+
+// RSV_REG_APP is a register that holds el0 information temporarily.
#define RSV_REG_APP R9
#define FPEN_NOTRAP 0x3
@@ -36,6 +40,12 @@
#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+// Saves a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not saved: R9, R18.
#define REGISTERS_SAVE(reg, offset) \
MOVD R0, offset+PTRACE_R0(reg); \
MOVD R1, offset+PTRACE_R1(reg); \
@@ -67,6 +77,12 @@
MOVD R29, offset+PTRACE_R29(reg); \
MOVD R30, offset+PTRACE_R30(reg);
+// Loads a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not loaded: R9, R18.
#define REGISTERS_LOAD(reg, offset) \
MOVD offset+PTRACE_R0(reg), R0; \
MOVD offset+PTRACE_R1(reg), R1; \
@@ -98,7 +114,7 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-//NOP
+// NOP-s
#define nop31Instructions() \
WORD $0xd503201f; \
WORD $0xd503201f; \
@@ -254,6 +270,7 @@
#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+// LOAD_KERNEL_ADDRESS loads a kernel address.
#define LOAD_KERNEL_ADDRESS(from, to) \
MOVD from, to; \
ORR $0xffff000000000000, to, to;
@@ -263,15 +280,18 @@
LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
MOVD RSV_REG, RSP; \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
ISB $15; \
DSB $15;
+// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
#define SWITCH_TO_APP_PAGETABLE(from) \
MOVD CPU_TTBR0_APP(from), RSV_REG; \
WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
ISB $15; \
DSB $15;
+// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
#define SWITCH_TO_KVM_PAGETABLE(from) \
MOVD CPU_TTBR0_KVM(from), RSV_REG; \
WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
@@ -294,6 +314,7 @@
WORD $0xd5181040; \ //MSR R0, CPACR_EL1
ISB $15;
+// KERNEL_ENTRY_FROM_EL0 is the entry code of the vcpu from el0 to el1.
#define KERNEL_ENTRY_FROM_EL0 \
SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
@@ -315,19 +336,22 @@
WORD $0xd5384103; \ // MRS SP_EL0, R3
MOVD R3, PTRACE_SP(RSV_REG_APP);
+// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
#define KERNEL_ENTRY_FROM_EL1 \
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
- REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
WORD $0xd5384004; \ // MRS SPSR_EL1, R4
MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
MRS ELR_EL1, R4; \
MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \
MOVD RSP, R4; \
- MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG);
+ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
+ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// Halt halts execution.
TEXT ·Halt(SB),NOSPLIT,$0
- // clear bluepill.
+ // Clear bluepill.
WORD $0xd538d092 //MRS TPIDR_EL1, R18
CMP RSV_REG, R9
BNE mmio_exit
@@ -341,8 +365,22 @@ mmio_exit:
// MMIO_EXIT.
MOVD $0, R9
MOVD R0, 0xffff000000001000(R9)
- B ·kernelExitToEl1(SB)
+ RET
+
+// HaltAndResume halts execution and point the pointer to the resume function.
+TEXT ·HaltAndResume(SB),NOSPLIT,$0
+ BL ·Halt(SB)
+ B ·kernelExitToEl1(SB) // Resume.
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
+// Shutdown stops the guest.
TEXT ·Shutdown(SB),NOSPLIT,$0
// PSCI EVENT.
MOVD $0x84000009, R0
@@ -429,6 +467,7 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
ERET()
+// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
IRQ_DISABLE
MOVD R8, RSV_REG
@@ -437,18 +476,23 @@ TEXT ·Start(SB),NOSPLIT,$0
B ·kernelExitToEl1(SB)
+// El1_sync_invalid is the handler for an invalid EL1_sync.
TEXT ·El1_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_irq_invalid is the handler for an invalid El1_irq.
TEXT ·El1_irq_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_fiq_invalid is the handler for an invalid El1_fiq.
TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_error_invalid is the handler for an invalid El1_error.
TEXT ·El1_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_sync is the handler for El1_sync.
TEXT ·El1_sync(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL1
WORD $0xd5385219 // MRS ESR_EL1, R25
@@ -484,10 +528,10 @@ el1_da:
MOVD $PageFault, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el1_ia:
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el1_sp_pc:
B ·Shutdown(SB)
@@ -496,7 +540,9 @@ el1_undef:
B ·Shutdown(SB)
el1_svc:
- B ·Halt(SB)
+ MOVD $0, CPU_ERROR_CODE(RSV_REG)
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+ B ·HaltEl1SvcAndResume(SB)
el1_dbg:
B ·Shutdown(SB)
@@ -508,15 +554,19 @@ el1_fpsimd_acc:
el1_invalid:
B ·Shutdown(SB)
+// El1_irq is the handler for El1_irq.
TEXT ·El1_irq(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_fiq is the handler for El1_fiq.
TEXT ·El1_fiq(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_error is the handler for El1_error.
TEXT ·El1_error(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El0_sync is the handler for El0_sync.
TEXT ·El0_sync(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL0
WORD $0xd5385219 // MRS ESR_EL1, R25
@@ -554,7 +604,7 @@ el0_svc:
MOVD $Syscall, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el0_da:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
@@ -568,7 +618,7 @@ el0_da:
MOVD $PageFault, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el0_ia:
B ·Shutdown(SB)
@@ -601,7 +651,19 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0
B ·Shutdown(SB)
TEXT ·El0_error(SB),NOSPLIT,$0
- B ·Shutdown(SB)
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $VirtualizationException, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
@@ -615,6 +677,7 @@ TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0
TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// Vectors implements exception vector table.
TEXT ·Vectors(SB),NOSPLIT,$0
B ·El1_sync_invalid(SB)
nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index c3d341998..ccacaea6b 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -16,6 +16,14 @@
package ring0
+// HaltAndResume halts execution and point the pointer to the resume function.
+//go:nosplit
+func HaltAndResume()
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+//go:nosplit
+func HaltEl1SvcAndResume()
+
// init initializes architecture-specific state.
func (k *Kernel) init(opts KernelOpts) {
// Save the root page tables.
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
index 8c960c749..057fb5c69 100644
--- a/pkg/sentry/platform/ring0/offsets_arm64.go
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -85,6 +85,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+ fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
p := &syscall.PtraceRegs{}
fmt.Fprintf(w, "\n// Ptrace registers.\n")
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 971eed7fa..4f2406ce3 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -7,7 +7,7 @@ go_template(
name = "generic_walker",
srcs = select_arch(
amd64 = ["walker_amd64.go"],
- arm64 = ["walker_amd64.go"],
+ arm64 = ["walker_arm64.go"],
),
opt_types = [
"Visitor",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 79e16d6e8..4d42d29cb 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/syserror",
+ "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00265f15b..8834a1e1a 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -189,7 +190,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
@@ -282,19 +283,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -338,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
}
// PackTClass packs an IPV6_TCLASS socket control message.
-func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IPV6,
@@ -348,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -372,12 +379,16 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
return buf
}
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -404,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
return space
}
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
@@ -437,7 +458,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -448,7 +469,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -462,7 +483,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
// Unknown message type.
@@ -476,7 +497,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -489,7 +522,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 5a07d5d0e..023bad156 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -10,6 +10,7 @@ go_library(
"save_restore.go",
"socket.go",
"socket_unsafe.go",
+ "sockopt_impl.go",
"stack.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index bde4c7a1e..22f78d2e2 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -126,7 +126,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
return uint64(n), nil
}
- return readv(s.fd, iovecsFromBlockSeq(dsts))
+ return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
}))
return int64(n), err
}
@@ -149,7 +149,7 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
return uint64(n), nil
}
- return writev(s.fd, iovecsFromBlockSeq(srcs))
+ return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
}))
return int64(n), err
}
@@ -285,11 +285,11 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
}
// Whitelist options and constrain option length.
- var optlen int
+ optlen := getSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
@@ -330,12 +330,14 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
// SetSockOpt implements socket.Socket.SetSockOpt.
func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
// Whitelist options and constrain option length.
- var optlen int
+ optlen := setSockOptLen(t, level, name)
switch level {
case linux.SOL_IP:
switch name {
case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
@@ -353,6 +355,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
optlen = sizeofInt32
}
}
+
if optlen == 0 {
// Pretend to accept socket options we don't understand. This seems
// dangerous, but it's what netstack does...
@@ -402,7 +405,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We always do a non-blocking recv*().
sysflags := flags | syscall.MSG_DONTWAIT
- iovs := iovecsFromBlockSeq(dsts)
+ iovs := safemem.IovecsFromBlockSeq(dsts)
msg := syscall.Msghdr{
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
@@ -472,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
case syscall.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
}
+
case syscall.SOL_IPV6:
switch unixCmsg.Header.Type {
case syscall.IPV6_TCLASS:
@@ -522,7 +532,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return uint64(n), nil
}
- iovs := iovecsFromBlockSeq(srcs)
+ iovs := safemem.IovecsFromBlockSeq(srcs)
msg := syscall.Msghdr{
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
@@ -567,21 +577,6 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(n), syserr.FromError(err)
}
-func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec {
- iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
- for ; !bs.IsEmpty(); bs = bs.Tail() {
- b := bs.Head()
- iovs = append(iovs, syscall.Iovec{
- Base: &b.ToSlice()[0],
- Len: uint64(b.Len()),
- })
- // We don't need to care about b.NeedSafecopy(), because the host
- // kernel will handle such address ranges just fine (by returning
- // EFAULT).
- }
- return iovs
-}
-
func translateIOSyscallError(err error) error {
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
return syserror.ErrWouldBlock
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 034eca676..a48082631 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -310,6 +310,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return addrs
}
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
func (s *Stack) SupportsIPv6() bool {
return s.supportsIPv6
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index fa2a2cb66..7cd2ce55b 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -5,7 +5,11 @@ package(licenses = ["notice"])
go_library(
name = "netfilter",
srcs = [
+ "extensions.go",
"netfilter.go",
+ "targets.go",
+ "tcp_matcher.go",
+ "udp_matcher.go",
],
# This target depends on netstack and should only be used by epsocket,
# which is allowed to depend on netstack.
@@ -17,6 +21,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
new file mode 100644
index 000000000..b4b244abf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TODO(gvisor.dev/issue/170): The following per-matcher params should be
+// supported:
+// - Table name
+// - Match size
+// - User size
+// - Hooks
+// - Proto
+// - Family
+
+// matchMaker knows how to (un)marshal the matcher named name().
+type matchMaker interface {
+ // name is the matcher name as stored in the xt_entry_match struct.
+ name() string
+
+ // marshal converts from an iptables.Matcher to an ABI struct.
+ marshal(matcher iptables.Matcher) []byte
+
+ // unmarshal converts from the ABI matcher struct to an
+ // iptables.Matcher.
+ unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+}
+
+// matchMakers maps the name of supported matchers to the matchMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var matchMakers = map[string]matchMaker{}
+
+// registermatchMaker should be called by match extensions to register them
+// with the netfilter package.
+func registerMatchMaker(mm matchMaker) {
+ if _, ok := matchMakers[mm.name()]; ok {
+ panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name()))
+ }
+ matchMakers[mm.name()] = mm
+}
+
+func marshalMatcher(matcher iptables.Matcher) []byte {
+ matchMaker, ok := matchMakers[matcher.Name()]
+ if !ok {
+ panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
+ }
+ return matchMaker.marshal(matcher)
+}
+
+// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and
+// data appended at the end.
+func marshalEntryMatch(name string, data []byte) []byte {
+ nflog("marshaling matcher %q", name)
+
+ // We have to pad this struct size to a multiple of 8 bytes.
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ matcher := linux.KernelXTEntryMatch{
+ XTEntryMatch: linux.XTEntryMatch{
+ MatchSize: uint16(size),
+ },
+ Data: data,
+ }
+ copy(matcher.Name[:], name)
+
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ return append(buf, make([]byte, size-len(buf))...)
+}
+
+func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+ matchMaker, ok := matchMakers[match.Name.String()]
+ if !ok {
+ return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
+ }
+ return matchMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 6ef740463..2ec11f6ac 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,6 +17,7 @@
package netfilter
import (
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -34,9 +35,12 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
-// metadata is opaque to netstack. It holds data that we need to translate
-// between Linux's and netstack's iptables representations.
-// TODO(gvisor.dev/issue/170): This might be removable.
+// Metadata is used to verify that we are correctly serializing and
+// deserializing iptables into structs consumable by the iptables tool. We save
+// a metadata struct when the tables are written, and when they are read out we
+// verify that certain fields are the same.
+//
+// metadata is used by this serialization/deserializing code, not netstack.
type metadata struct {
HookEntry [linux.NF_INET_NUMHOOKS]uint32
Underflow [linux.NF_INET_NUMHOOKS]uint32
@@ -44,6 +48,13 @@ type metadata struct {
Size uint32
}
+// nflog logs messages related to the writing and reading of iptables.
+func nflog(format string, args ...interface{}) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
@@ -55,7 +66,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
// Find the appropriate table.
table, err := findTable(stack, info.Name)
if err != nil {
- return linux.IPTGetinfo{}, err
+ nflog("%v", err)
+ return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
}
// Get the hooks that apply to this table.
@@ -72,6 +84,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
info.NumEntries = metadata.NumEntries
info.Size = metadata.Size
+ nflog("returning info: %+v", info)
+
return info, nil
}
@@ -80,34 +94,40 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
// Find the appropriate table.
table, err := findTable(stack, userEntries.Name)
if err != nil {
- return linux.KernelIPTGetEntries{}, err
+ nflog("%v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(userEntries.Name.String(), table)
+ entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table)
if err != nil {
- return linux.KernelIPTGetEntries{}, err
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+ if meta != table.Metadata().(metadata) {
+ panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta))
}
if binary.Size(entries) > uintptr(outLen) {
- log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen))
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, *syserr.Error) {
+func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
ipt := stack.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, syserr.ErrInvalidArgument
+ return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
@@ -135,28 +155,31 @@ func FillDefaultIPTables(stack *stack.Stack) {
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
// The table name has to fit in the struct.
if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
- log.Warningf("Table name %q too long.", tablename)
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename)
}
copy(entries.Name[:], tablename)
for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
// Is this a chain entry point?
for hook, hookRuleIdx := range table.BuiltinChains {
if hookRuleIdx == ruleIdx {
+ nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
meta.HookEntry[hook] = entries.Size
}
}
// Is this a chain underflow point?
for underflow, underflowRuleIdx := range table.Underflows {
if underflowRuleIdx == ruleIdx {
+ nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
meta.Underflow[underflow] = entries.Size
}
}
@@ -176,6 +199,10 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
// Serialize the matcher and add it to the
// entry.
serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
entry.Elems = append(entry.Elems, serialized...)
entry.NextOffset += uint16(len(serialized))
entry.TargetOffset += uint16(len(serialized))
@@ -183,41 +210,46 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
// Serialize and append the target.
serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
entry.Elems = append(entry.Elems, serialized...)
entry.NextOffset += uint16(len(serialized))
+ nflog("convert to binary: adding entry: %+v", entry)
+
entries.Size += uint32(entry.NextOffset)
entries.Entrytable = append(entries.Entrytable, entry)
meta.NumEntries++
}
+ nflog("convert to binary: finished with an marshalled size of %d", meta.Size)
meta.Size = entries.Size
return entries, meta, nil
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
- switch matcher.(type) {
- default:
- // TODO(gvisor.dev/issue/170): We don't support any matchers
- // yet, so any call to marshalMatcher will panic.
- panic(fmt.Errorf("unknown matcher of type %T", matcher))
- }
-}
-
func marshalTarget(target iptables.Target) []byte {
- switch target.(type) {
- case iptables.UnconditionalAcceptTarget:
- return marshalStandardTarget(iptables.Accept)
- case iptables.UnconditionalDropTarget:
- return marshalStandardTarget(iptables.Drop)
+ switch tg := target.(type) {
+ case iptables.AcceptTarget:
+ return marshalStandardTarget(iptables.RuleAccept)
+ case iptables.DropTarget:
+ return marshalStandardTarget(iptables.RuleDrop)
case iptables.ErrorTarget:
- return marshalErrorTarget()
+ return marshalErrorTarget(errorTargetName)
+ case iptables.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case iptables.ReturnTarget:
+ return marshalStandardTarget(iptables.RuleReturn)
+ case JumpTarget:
+ return marshalJumpTarget(tg)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
-func marshalStandardTarget(verdict iptables.Verdict) []byte {
+func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
+
// The target's name will be the empty string.
target := linux.XTStandardTarget{
Target: linux.XTEntryTarget{
@@ -230,66 +262,77 @@ func marshalStandardTarget(verdict iptables.Verdict) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-func marshalErrorTarget() []byte {
+func marshalErrorTarget(errorName string) []byte {
// This is an error target named error
target := linux.XTErrorTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTErrorTarget,
},
}
- copy(target.Name[:], errorTargetName)
+ copy(target.Name[:], errorName)
copy(target.Target.Name[:], errorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.Verdict) int32 {
+func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
switch verdict {
- case iptables.Accept:
+ case iptables.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.Drop:
+ case iptables.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.Queue:
- return -linux.NF_QUEUE - 1
- case iptables.Return:
+ case iptables.RuleReturn:
return linux.NF_RETURN
- case iptables.Jump:
+ default:
// TODO(gvisor.dev/issue/170): Support Jump.
- panic("Jump isn't supported yet")
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
-// translateToStandardVerdict translates from the value in a
+// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) {
+func translateToStandardTarget(val int32) (iptables.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.Accept, nil
+ return iptables.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.Drop, nil
+ return iptables.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
- log.Warningf("Unsupported iptables verdict QUEUE.")
+ return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- log.Warningf("Unsupported iptables verdict RETURN.")
+ return iptables.ReturnTarget{}, nil
default:
- log.Warningf("Unknown iptables verdict %d.", val)
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
- return iptables.Invalid, syserr.ErrInvalidArgument
}
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
- printReplace(optVal)
-
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
- log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal))
+ nflog("optVal has insufficient size for replace %d", len(optVal))
return syserr.ErrInvalidArgument
}
var replace linux.IPTReplace
@@ -303,25 +346,32 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
case iptables.TablenameFilter:
table = iptables.EmptyFilterTable()
default:
- log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
}
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
// Convert input into a list of rules and their offsets.
var offset uint32
- var offsets []uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
// Get the struct ipt_entry.
if len(optVal) < linux.SizeOfIPTEntry {
- log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal))
+ nflog("optVal has insufficient size for entry %d", len(optVal))
return syserr.ErrInvalidArgument
}
var entry linux.IPTEntry
buf := optVal[:linux.SizeOfIPTEntry]
- optVal = optVal[linux.SizeOfIPTEntry:]
binary.Unmarshal(buf, usermem.ByteOrder, &entry)
- if entry.TargetOffset != linux.SizeOfIPTEntry {
- // TODO(gvisor.dev/issue/170): Support matchers.
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIPTEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIPTEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
return syserr.ErrInvalidArgument
}
@@ -329,22 +379,50 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// filtering fields.
filter, err := filterFromIPTIP(entry.IP)
if err != nil {
- return err
+ nflog("bad iptip: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return syserr.ErrInvalidArgument
}
+ optVal = optVal[matchersSize:]
// Get the target of the rule.
- target, consumed, err := parseTarget(optVal)
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(optVal[:targetSize])
if err != nil {
- return err
+ nflog("failed to parse target: %v", err)
+ return syserr.ErrInvalidArgument
}
- optVal = optVal[consumed:]
+ optVal = optVal[targetSize:]
table.Rules = append(table.Rules, iptables.Rule{
- Filter: filter,
- Target: target,
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
})
- offsets = append(offsets, offset)
- offset += linux.SizeOfIPTEntry + consumed
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
+ }
}
// Go through the list of supported hooks for this table and, for each
@@ -352,32 +430,77 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
- for ruleIdx, offset := range offsets {
+ for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP")
+ return syserr.ErrInvalidArgument
+ }
table.Underflows[hk] = ruleIdx
}
}
if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
- log.Warningf("Hook %v is unset.", hk)
+ nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
- log.Warningf("Underflow %v is unset.", hk)
+ nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
}
}
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ target, ok := rule.Target.(iptables.UserChainTarget)
+ if !ok {
+ continue
+ }
+
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
+ table.UserChains[target.Name] = ruleIdx + 1
+ }
+
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
// TODO(gvisor.dev/issue/170): Support other chains.
// Since we only support modifying the INPUT chain right now, make sure
// all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.UnconditionalAcceptTarget); !ok {
- log.Warningf("Hook %d is unsupported.", hook)
+ if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok {
+ nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
}
@@ -401,12 +524,56 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
return nil
}
-// parseTarget parses a target from the start of optVal and returns the target
-// along with the number of bytes it occupies in optVal.
-func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) {
+// parseMatchers parses 0 or more matchers from optVal. optVal should contain
+// only the matchers.
+func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
+ nflog("set entries: parsing matchers of size %d", len(optVal))
+ var matchers []iptables.Matcher
+ for len(optVal) > 0 {
+ nflog("set entries: optVal has len %d", len(optVal))
+
+ // Get the XTEntryMatch.
+ if len(optVal) < linux.SizeOfXTEntryMatch {
+ return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
+ }
+ var match linux.XTEntryMatch
+ buf := optVal[:linux.SizeOfXTEntryMatch]
+ binary.Unmarshal(buf, usermem.ByteOrder, &match)
+ nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
+
+ // Check some invariants.
+ if match.MatchSize < linux.SizeOfXTEntryMatch {
+
+ return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
+ }
+ if len(optVal) < int(match.MatchSize) {
+ return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal))
+ }
+
+ // Parse the specific matcher.
+ matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %v", err)
+ }
+ matchers = append(matchers, matcher)
+
+ // TODO(gvisor.dev/issue/170): Check the revision field.
+ optVal = optVal[match.MatchSize:]
+ }
+
+ if len(optVal) != 0 {
+ return nil, errors.New("optVal should be exhausted after parsing matchers")
+ }
+
+ return matchers, nil
+}
+
+// parseTarget parses a target from optVal. optVal should contain only the
+// target.
+func parseTarget(optVal []byte) (iptables.Target, error) {
+ nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
- log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal))
- return nil, 0, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
}
var target linux.XTEntryTarget
buf := optVal[:linux.SizeOfXTEntryTarget]
@@ -414,32 +581,24 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) {
switch target.Name.String() {
case "":
// Standard target.
- if len(optVal) < linux.SizeOfXTStandardTarget {
- log.Warningf("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal))
- return nil, 0, syserr.ErrInvalidArgument
+ if len(optVal) != linux.SizeOfXTStandardTarget {
+ return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
}
var standardTarget linux.XTStandardTarget
buf = optVal[:linux.SizeOfXTStandardTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
- verdict, err := translateToStandardVerdict(standardTarget.Verdict)
- if err != nil {
- return nil, 0, err
- }
- switch verdict {
- case iptables.Accept:
- return iptables.UnconditionalAcceptTarget{}, linux.SizeOfXTStandardTarget, nil
- case iptables.Drop:
- return iptables.UnconditionalDropTarget{}, linux.SizeOfXTStandardTarget, nil
- default:
- panic(fmt.Sprintf("Unknown verdict: %v", verdict))
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
}
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
case errorTargetName:
// Error target.
- if len(optVal) < linux.SizeOfXTErrorTarget {
- log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal))
- return nil, 0, syserr.ErrInvalidArgument
+ if len(optVal) != linux.SizeOfXTErrorTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
}
var errorTarget linux.XTErrorTarget
buf = optVal[:linux.SizeOfXTErrorTarget]
@@ -452,24 +611,24 @@ func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) {
// somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch errorTarget.Name.String() {
+ switch name := errorTarget.Name.String(); name {
case errorTargetName:
- return iptables.ErrorTarget{}, linux.SizeOfXTErrorTarget, nil
+ nflog("set entries: error target")
+ return iptables.ErrorTarget{}, nil
default:
- log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String())
- return nil, 0, syserr.ErrInvalidArgument
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return iptables.UserChainTarget{Name: name}, nil
}
}
// Unknown target.
- log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
- return nil, 0, syserr.ErrInvalidArgument
+ return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, *syserr.Error) {
+func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- log.Warningf("netfilter: unsupported fields in struct iptip: %+v", iptip)
- return iptables.IPHeaderFilter{}, syserr.ErrInvalidArgument
+ return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
}
return iptables.IPHeaderFilter{
Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
@@ -492,6 +651,18 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool {
iptip.InverseFlags != 0
}
+func validUnderflow(rule iptables.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ switch rule.Target.(type) {
+ case iptables.AcceptTarget, iptables.DropTarget:
+ return true
+ default:
+ return false
+ }
+}
+
func hookFromLinux(hook int) iptables.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
@@ -507,52 +678,3 @@ func hookFromLinux(hook int) iptables.Hook {
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
-
-// printReplace prints information about the struct ipt_replace in optVal. It
-// is only for debugging.
-func printReplace(optVal []byte) {
- // Basic replace info.
- var replace linux.IPTReplace
- replaceBuf := optVal[:linux.SizeOfIPTReplace]
- optVal = optVal[linux.SizeOfIPTReplace:]
- binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
- log.Infof("Replacing table %q: %+v", replace.Name.String(), replace)
-
- // Read in the list of entries at the end of replace.
- var totalOffset uint16
- for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
- var entry linux.IPTEntry
- entryBuf := optVal[:linux.SizeOfIPTEntry]
- binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry)
- log.Infof("Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry)
-
- totalOffset += entry.NextOffset
- if entry.TargetOffset == linux.SizeOfIPTEntry {
- log.Infof("Entry has no matches.")
- } else {
- log.Infof("Entry has matches.")
- }
-
- var target linux.XTEntryTarget
- targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget]
- binary.Unmarshal(targetBuf, usermem.ByteOrder, &target)
- log.Infof("Target named %q: %+v", target.Name.String(), target)
-
- switch target.Name.String() {
- case "":
- var standardTarget linux.XTStandardTarget
- stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget]
- binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget)
- log.Infof("Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict)
- case errorTargetName:
- var errorTarget linux.XTErrorTarget
- etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget]
- binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget)
- log.Infof("Error target with name %q.", errorTarget.Name.String())
- default:
- log.Infof("Unknown target type.")
- }
-
- optVal = optVal[entry.NextOffset:]
- }
-}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..c421b87cf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+)
+
+// JumpTarget implements iptables.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements iptables.Target.Action.
+func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
+ return iptables.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
new file mode 100644
index 000000000..f9945e214
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameTCP = "tcp"
+
+func init() {
+ registerMatchMaker(tcpMarshaler{})
+}
+
+// tcpMarshaler implements matchMaker for TCP matching.
+type tcpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (tcpMarshaler) name() string {
+ return matcherNameTCP
+}
+
+// marshal implements matchMaker.marshal.
+func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+ matcher := mr.(*TCPMatcher)
+ xttcp := linux.XTTCP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTTCP)
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+ if len(buf) < linux.SizeOfXTTCP {
+ return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.XTTCP
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTTCP: %+v", matchData)
+
+ if matchData.Option != 0 ||
+ matchData.FlagMask != 0 ||
+ matchData.FlagCompare != 0 ||
+ matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported TCP matcher flags set")
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber {
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ }
+
+ return &TCPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// TCPMatcher matches TCP packets and their headers. It implements Matcher.
+type TCPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*TCPMatcher) Name() string {
+ return matcherNameTCP
+}
+
+// Match implements Matcher.Match.
+func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ // Now we need the transport header. However, this may not have been set
+ // yet.
+ // TODO(gvisor.dev/issue/170): Parsing the transport header should
+ // ultimately be moved into the iptables.Check codepath as matchers are
+ // added.
+ var tcpHeader header.TCP
+ if pkt.TransportHeader != nil {
+ tcpHeader = header.TCP(pkt.TransportHeader)
+ } else {
+ // The TCP header hasn't been parsed yet. We have to do it here.
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we hotdrop the
+ // packet.
+ return false, true
+ }
+ tcpHeader = header.TCP(pkt.Data.First())
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
new file mode 100644
index 000000000..86aa11696
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameUDP = "udp"
+
+func init() {
+ registerMatchMaker(udpMarshaler{})
+}
+
+// udpMarshaler implements matchMaker for UDP matching.
+type udpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (udpMarshaler) name() string {
+ return matcherNameUDP
+}
+
+// marshal implements matchMaker.marshal.
+func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+ matcher := mr.(*UDPMatcher)
+ xtudp := linux.XTUDP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTUDP)
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+ if len(buf) < linux.SizeOfXTUDP {
+ return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may exceed what's
+ // strictly necessary to hold matchData.
+ var matchData linux.XTUDP
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTUDP: %+v", matchData)
+
+ if matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported UDP matcher inverse flags set")
+ }
+
+ if filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ }
+
+ return &UDPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// UDPMatcher matches UDP packets and their headers. It implements Matcher.
+type UDPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*UDPMatcher) Name() string {
+ return matcherNameUDP
+}
+
+// Match implements Matcher.Match.
+func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the iptables.Check codepath as matchers are added.
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ // Now we need the transport header. However, this may not have been set
+ // yet.
+ // TODO(gvisor.dev/issue/170): Parsing the transport header should
+ // ultimately be moved into the iptables.Check codepath as matchers are
+ // added.
+ var udpHeader header.UDP
+ if pkt.TransportHeader != nil {
+ udpHeader = header.UDP(pkt.TransportHeader)
+ } else {
+ // The UDP header hasn't been parsed yet. We have to do it here.
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we hotdrop the
+ // packet.
+ return false, true
+ }
+ udpHeader = header.UDP(pkt.Data.First())
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index f8b8e467d..1911cd9b8 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -33,3 +33,15 @@ go_library(
"//pkg/waiter",
],
)
+
+go_test(
+ name = "netlink_test",
+ size = "small",
+ srcs = [
+ "message_test.go",
+ ],
+ deps = [
+ ":netlink",
+ "//pkg/abi/linux",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index b21e0ca4b..0899c61d1 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -23,15 +23,16 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// alignUp rounds a length up to an alignment.
+// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) &^ (int(align) - 1)
+func alignPad(length int, align uint) int {
+ return binary.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
type Message struct {
+ hdr linux.NetlinkMessageHeader
buf []byte
}
@@ -40,10 +41,86 @@ type Message struct {
// The header length will be updated by Finalize.
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
+ hdr: hdr,
buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
}
}
+// ParseMessage parses the first message seen at buf, returning the rest of the
+// buffer. If message is malformed, ok of false is returned. For last message,
+// padding check is loose, if there isn't enought padding, whole buf is consumed
+// and ok is set to true.
+func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
+ b := BytesView(buf)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+ var hdr linux.NetlinkMessageHeader
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ // Msg portion.
+ totalMsgLen := int(hdr.Length)
+ _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+
+ // Padding.
+ numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return
+ }
+
+ return &Message{
+ hdr: hdr,
+ buf: buf[:totalMsgLen],
+ }, []byte(b), true
+}
+
+// Header returns the header of this message.
+func (m *Message) Header() linux.NetlinkMessageHeader {
+ return m.hdr
+}
+
+// GetData unmarshals the payload message header from this netlink message, and
+// returns the attributes portion.
+func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+ b := BytesView(m.buf)
+
+ _, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return nil, false
+ }
+
+ size := int(binary.Size(msg))
+ msgBytes, ok := b.Extract(size)
+ if !ok {
+ return nil, false
+ }
+ binary.Unmarshal(msgBytes, usermem.ByteOrder, msg)
+
+ numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return nil, false
+ }
+
+ return AttrsView(b), true
+}
+
// Finalize returns the []byte containing the entire message, with the total
// length set in the message header. The Message must not be modified after
// calling Finalize.
@@ -54,7 +131,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -89,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) {
m.Put(v)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -106,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) {
m.putZeros(1)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -157,3 +234,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
ms.Messages = append(ms.Messages, m)
return m
}
+
+// AttrsView is a view into the attributes portion of a netlink message.
+type AttrsView []byte
+
+// Empty returns whether there is no attribute left in v.
+func (v AttrsView) Empty() bool {
+ return len(v) == 0
+}
+
+// ParseFirst parses first netlink attribute at the beginning of v.
+func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
+ b := BytesView(v)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+
+ _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
+ if !ok {
+ return
+ }
+
+ return hdr, value, AttrsView(b), ok
+}
+
+// BytesView supports extracting data from a byte slice with bounds checking.
+type BytesView []byte
+
+// Extract removes the first n bytes from v and returns it. If n is out of
+// bounds, it returns false.
+func (v *BytesView) Extract(n int) ([]byte, bool) {
+ if n < 0 || n > len(*v) {
+ return nil, false
+ }
+ extracted := (*v)[:n]
+ *v = (*v)[n:]
+ return extracted, true
+}
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
new file mode 100644
index 000000000..ef13d9386
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -0,0 +1,312 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message_test
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+)
+
+type dummyNetlinkMsg struct {
+ Foo uint16
+}
+
+func TestParseMessage(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ header linux.NetlinkMessageHeader
+ dataMsg *dummyNetlinkMsg
+ restLen int
+ ok bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid with next message",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ 0xFF, // Next message (rest)
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 1,
+ ok: true,
+ },
+ {
+ desc: "valid for last message without padding",
+ input: []byte{
+ 0x12, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 18,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid for last message not to be aligned",
+ input: []byte{
+ 0x13, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ 0x00, // Excessive 1 byte permitted at end
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 19,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "header.Length too short",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header.Length too long",
+ input: []byte{
+ 0xFF, 0xFF, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header incomplete",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ },
+ ok: false,
+ },
+ {
+ desc: "empty message",
+ input: []byte{},
+ ok: false,
+ },
+ }
+ for _, test := range tests {
+ msg, rest, ok := netlink.ParseMessage(test.input)
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ continue
+ }
+ if !test.ok {
+ continue
+ }
+ if !reflect.DeepEqual(msg.Header(), test.header) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
+ }
+
+ dataMsg := &dummyNetlinkMsg{}
+ _, dataOk := msg.GetData(dataMsg)
+ if !dataOk {
+ t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
+ } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
+ }
+
+ if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
+
+func TestAttrView(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ // Outputs for ParseFirst.
+ hdr linux.NetlinkAttrHeader
+ value []byte
+ restLen int
+ ok bool
+
+ // Outputs for Empty.
+ isEmpty bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x06, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 6,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment with rest data",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ 0xFF, 0xFE, // Rest data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 2,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too long",
+ input: []byte{
+ 0xFF, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too short",
+ input: []byte{
+ 0x01, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "empty",
+ input: []byte{},
+ ok: false,
+ isEmpty: true,
+ },
+ }
+ for _, test := range tests {
+ attrs := netlink.AttrsView(test.input)
+
+ // Test ParseFirst().
+ hdr, value, rest, ok := attrs.ParseFirst()
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ } else if test.ok {
+ if !reflect.DeepEqual(hdr, test.hdr) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr)
+ }
+ if !bytes.Equal(value, test.value) {
+ t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value)
+ }
+ if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest)
+ }
+ }
+
+ // Test Empty().
+ if got, want := attrs.Empty(), test.isEmpty; got != want {
+ t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 07f860a49..b0dc70e5c 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -42,7 +42,7 @@ type Protocol interface {
// If err == nil, any messages added to ms will be sent back to the
// other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
// message to be sent even if ms contains no messages.
- ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error
+ ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error
}
// Provider is a function that creates a new Protocol for a specific netlink
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 622a1eafc..93127398d 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -10,13 +10,11 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/context",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/socket/netlink",
"//pkg/syserr",
- "//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 2b3c7f5b3..c84d8bd7c 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -17,16 +17,15 @@ package route
import (
"bytes"
+ "syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/usermem"
)
// commandKind describes the operational class of a message type.
@@ -69,13 +68,7 @@ func (p *Protocol) CanSend() bool {
}
// dumpLinks handles RTM_GETLINK dump requests.
-func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
- // TODO(b/68878065): Only the dump variant of the types below are
- // supported.
- if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP {
- return syserr.ErrNotSupported
- }
-
+func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
// ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
// userspace applications (including glibc) still include rtgenmsg.
@@ -99,44 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
return nil
}
- for id, i := range stack.Interfaces() {
- m := ms.AddMessage(linux.NetlinkMessageHeader{
- Type: linux.RTM_NEWLINK,
- })
+ for idx, i := range stack.Interfaces() {
+ addNewLinkMessage(ms, idx, i)
+ }
- m.Put(linux.InterfaceInfoMessage{
- Family: linux.AF_UNSPEC,
- Type: i.DeviceType,
- Index: id,
- Flags: i.Flags,
- })
+ return nil
+}
- m.PutAttrString(linux.IFLA_IFNAME, i.Name)
- m.PutAttr(linux.IFLA_MTU, i.MTU)
+// getLinks handles RTM_GETLINK requests.
+func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
- mac := make([]byte, 6)
- brd := mac
- if len(i.Addr) > 0 {
- mac = i.Addr
- brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ // Parse message.
+ var ifi linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifi)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Parse attributes.
+ var byName []byte
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
}
- m.PutAttr(linux.IFLA_ADDRESS, mac)
- m.PutAttr(linux.IFLA_BROADCAST, brd)
+ attrs = rest
- // TODO(gvisor.dev/issue/578): There are many more attributes.
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ byName = value[:len(value)-1]
+
+ // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK.
+ }
}
+ found := false
+ for idx, i := range stack.Interfaces() {
+ switch {
+ case ifi.Index > 0:
+ if idx != ifi.Index {
+ continue
+ }
+ case byName != nil:
+ if string(byName) != i.Name {
+ continue
+ }
+ default:
+ // Criteria not specified.
+ return syserr.ErrInvalidArgument
+ }
+
+ addNewLinkMessage(ms, idx, i)
+ found = true
+ break
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
return nil
}
-// dumpAddrs handles RTM_GETADDR dump requests.
-func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
- // TODO(b/68878065): Only the dump variant of the types below are
- // supported.
- if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP {
- return syserr.ErrNotSupported
+// addNewLinkMessage appends RTM_NEWLINK message for the given interface into
+// the message set.
+func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWLINK,
+ })
+
+ m.Put(linux.InterfaceInfoMessage{
+ Family: linux.AF_UNSPEC,
+ Type: i.DeviceType,
+ Index: idx,
+ Flags: i.Flags,
+ })
+
+ m.PutAttrString(linux.IFLA_IFNAME, i.Name)
+ m.PutAttr(linux.IFLA_MTU, i.MTU)
+
+ mac := make([]byte, 6)
+ brd := mac
+ if len(i.Addr) > 0 {
+ mac = i.Addr
+ brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
}
+ m.PutAttr(linux.IFLA_ADDRESS, mac)
+ m.PutAttr(linux.IFLA_BROADCAST, brd)
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+}
+// dumpAddrs handles RTM_GETADDR dump requests.
+func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// RTM_GETADDR dump requests need not contain anything more than the
// netlink header and 1 byte protocol family common to all
// NETLINK_ROUTE requests.
@@ -168,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
Index: uint32(id),
})
+ m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
// TODO(gvisor.dev/issue/578): There are many more attributes.
@@ -252,12 +307,12 @@ func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) {
}
// parseForDestination parses a message as format of RouteMessage-RtAttr-dst.
-func parseForDestination(data []byte) ([]byte, *syserr.Error) {
+func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) {
var rtMsg linux.RouteMessage
- if len(data) < linux.SizeOfRouteMessage {
+ attrs, ok := msg.GetData(&rtMsg)
+ if !ok {
return nil, syserr.ErrInvalidArgument
}
- binary.Unmarshal(data[:linux.SizeOfRouteMessage], usermem.ByteOrder, &rtMsg)
// iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See
// commit bc234301af12. Note we don't check this flag for backward
// compatibility.
@@ -265,26 +320,15 @@ func parseForDestination(data []byte) ([]byte, *syserr.Error) {
return nil, syserr.ErrNotSupported
}
- data = data[linux.SizeOfRouteMessage:]
-
- // TODO(gvisor.dev/issue/1611): Add generic attribute parsing.
- var rtAttr linux.RtAttr
- if len(data) < linux.SizeOfRtAttr {
- return nil, syserr.ErrInvalidArgument
+ // Expect first attribute is RTA_DST.
+ if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST {
+ return value, nil
}
- binary.Unmarshal(data[:linux.SizeOfRtAttr], usermem.ByteOrder, &rtAttr)
- if rtAttr.Type != linux.RTA_DST {
- return nil, syserr.ErrInvalidArgument
- }
-
- if len(data) < int(rtAttr.Len) {
- return nil, syserr.ErrInvalidArgument
- }
- return data[linux.SizeOfRtAttr:rtAttr.Len], nil
+ return nil, syserr.ErrInvalidArgument
}
// dumpRoutes handles RTM_GETROUTE requests.
-func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// RTM_GETROUTE dump requests need not contain anything more than the
// netlink header and 1 byte protocol family common to all
// NETLINK_ROUTE requests.
@@ -295,10 +339,11 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade
return nil
}
+ hdr := msg.Header()
routeTables := stack.RouteTable()
if hdr.Flags == linux.NLM_F_REQUEST {
- dst, err := parseForDestination(data)
+ dst, err := parseForDestination(msg)
if err != nil {
return err
}
@@ -357,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade
return nil
}
+// newAddr handles RTM_NEWADDR requests.
+func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err == syscall.EEXIST {
+ flags := msg.Header().Flags
+ if flags&linux.NLM_F_EXCL != 0 {
+ return syserr.ErrExists
+ }
+ } else if err != nil {
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
-func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ hdr := msg.Header()
+
// All messages start with a 1 byte protocol family.
- if len(data) < 1 {
+ var family uint8
+ if _, ok := msg.GetData(&family); !ok {
// Linux ignores messages missing the protocol family. See
// net/core/rtnetlink.c:rtnetlink_rcv_msg.
return nil
@@ -374,16 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH
}
}
- switch hdr.Type {
- case linux.RTM_GETLINK:
- return p.dumpLinks(ctx, hdr, data, ms)
- case linux.RTM_GETADDR:
- return p.dumpAddrs(ctx, hdr, data, ms)
- case linux.RTM_GETROUTE:
- return p.dumpRoutes(ctx, hdr, data, ms)
- default:
- return syserr.ErrNotSupported
+ if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // TODO(b/68878065): Only the dump variant of the types below are
+ // supported.
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.dumpLinks(ctx, msg, ms)
+ case linux.RTM_GETADDR:
+ return p.dumpAddrs(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST {
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.getLink(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ case linux.RTM_NEWADDR:
+ return p.newAddr(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
}
+ return syserr.ErrNotSupported
}
// init registers the NETLINK_ROUTE provider.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index c4b95debb..2ca02567d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -644,47 +644,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
return nil
}
-func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error {
+func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
-
m.Put(linux.NetlinkErrorMessage{
Error: int32(-err.ToLinux().Number()),
Header: hdr,
})
- return nil
+}
+func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: 0,
+ Header: hdr,
+ })
}
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
for len(buf) > 0 {
- if len(buf) < linux.NetlinkMessageHeaderSize {
+ msg, rest, ok := ParseMessage(buf)
+ if !ok {
// Linux ignores messages that are too short. See
// net/netlink/af_netlink.c:netlink_rcv_skb.
break
}
-
- var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr)
-
- if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) {
- // Linux ignores malformed messages. See
- // net/netlink/af_netlink.c:netlink_rcv_skb.
- break
- }
-
- // Data from this message.
- data := buf[linux.NetlinkMessageHeaderSize:hdr.Length]
-
- // Advance to the next message.
- next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO)
- if next >= len(buf)-1 {
- next = len(buf) - 1
- }
- buf = buf[next:]
+ buf = rest
+ hdr := msg.Header()
// Ignore control messages.
if hdr.Type < linux.NLMSG_MIN_TYPE {
@@ -692,19 +683,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
}
ms := NewMessageSet(s.portID, hdr.Seq)
- var err *syserr.Error
- // TODO(b/68877377): ACKs not supported yet.
- if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
- err = syserr.ErrNotSupported
- } else {
-
- err = s.protocol.ProcessMessage(ctx, hdr, data, ms)
- }
- if err != nil {
- ms = NewMessageSet(s.portID, hdr.Seq)
- if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil {
- return err
- }
+ if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
+ dumpErrorMesage(hdr, ms, err)
+ } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+ dumpAckMesage(hdr, ms)
}
if err := s.sendResponse(ctx, ms); err != nil {
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
index 1ee4296bc..029ba21b5 100644
--- a/pkg/sentry/socket/netlink/uevent/protocol.go
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -49,7 +49,7 @@ func (p *Protocol) CanSend() bool {
}
// ProcessMessage implements netlink.Protocol.ProcessMessage.
-func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// Silently ignore all messages.
return nil
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 8619cc506..48c268bfa 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -712,14 +712,40 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
- if err != nil {
- return err
- }
- if err := s.checkFamily(family, true /* exact */); err != nil {
- return err
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
- addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -1260,6 +1286,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_DEFER_ACCEPT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPDeferAcceptOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1306,6 +1344,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
return ib, nil
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1402,6 +1456,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return o, nil
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1713,6 +1782,16 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+ case linux.TCP_DEFER_ACCEPT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1740,6 +1819,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1765,6 +1845,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1927,6 +2015,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1942,7 +2040,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
@@ -2039,7 +2136,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -2207,11 +2303,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
var copied int
// Copy as many views as possible into the user-provided buffer.
- for dst.NumBytes() != 0 {
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
err = s.fetchReadView()
if err != nil {
break
}
+ if dst.NumBytes() == 0 {
+ break
+ }
var n int
var e error
@@ -2368,10 +2469,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
},
}
}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5afff2564..5f181f017 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 31ea66eca..0692482e9 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -20,6 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -88,6 +90,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ var (
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ )
+ switch addr.Family {
+ case linux.AF_INET:
+ if len(addr.Addr) < header.IPv4AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv4AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv4.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
+
+ case linux.AF_INET6:
+ if len(addr.Addr) < header.IPv6AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv6AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv6.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
+
+ default:
+ return syserror.ENOTSUP
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: int(addr.PrefixLen),
+ },
+ }
+
+ // Attach address to interface.
+ if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network.
+ s.Stack.AddRoute(tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: tcpip.NICID(idx),
+ })
+ return nil
+}
+
// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
var rs tcp.ReceiveBufferSizeOption
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 762a946fe..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
@@ -30,7 +31,6 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
- "//pkg/sentry/socket/control",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index 85ec66fd3..71b92eaee 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -78,8 +78,8 @@ var linuxAMD64 = SyscallMap{
51: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
52: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
53: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
- 54: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
- 55: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 54: makeSyscallInfo("setsockopt", FD, SockOptLevel, SockOptName, SetSockOptVal, Hex /* length by value, not a pointer */),
+ 55: makeSyscallInfo("getsockopt", FD, SockOptLevel, SockOptName, GetSockOptVal, SockLen),
56: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
57: makeSyscallInfo("fork"),
58: makeSyscallInfo("vfork"),
@@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 8bc38545f..bd7361a52 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{
18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
19: makeSyscallInfo("eventfd2", Hex, Hex),
20: makeSyscallInfo("epoll_create1", Hex),
- 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
- 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
23: makeSyscallInfo("dup", FD),
24: makeSyscallInfo("dup3", FD, FD, Hex),
25: makeSyscallInfo("fcntl", FD, Hex, Hex),
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index d2079c85f..c0512de89 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
@@ -220,13 +219,13 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
fds := make(linux.ControlMessageRights, numRights)
@@ -295,7 +294,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
@@ -419,3 +418,227 @@ func sockFlags(flags int32) string {
}
return SocketFlagSet.Parse(uint64(flags))
}
+
+func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen usermem.Addr, maximumBlobSize uint, rval uintptr) string {
+ if int(rval) < 0 {
+ return hexNum(uint64(optVal))
+ }
+ if optVal == 0 {
+ return "null"
+ }
+ l, err := copySockLen(t, optLen)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", optLen, err)
+ }
+ return sockOptVal(t, level, optname, optVal, uint64(l), maximumBlobSize)
+}
+
+func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
+ switch optLen {
+ case 1:
+ var v uint8
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 2:
+ var v uint16
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 4:
+ var v uint32
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ default:
+ return dump(t, optVal, uint(optLen), maximumBlobSize)
+ }
+}
+
+var sockOptLevels = abi.ValueSet{
+ linux.SOL_IP: "SOL_IP",
+ linux.SOL_SOCKET: "SOL_SOCKET",
+ linux.SOL_TCP: "SOL_TCP",
+ linux.SOL_UDP: "SOL_UDP",
+ linux.SOL_IPV6: "SOL_IPV6",
+ linux.SOL_ICMPV6: "SOL_ICMPV6",
+ linux.SOL_RAW: "SOL_RAW",
+ linux.SOL_PACKET: "SOL_PACKET",
+ linux.SOL_NETLINK: "SOL_NETLINK",
+}
+
+var sockOptNames = map[uint64]abi.ValueSet{
+ linux.SOL_IP: {
+ linux.IP_TTL: "IP_TTL",
+ linux.IP_MULTICAST_TTL: "IP_MULTICAST_TTL",
+ linux.IP_MULTICAST_IF: "IP_MULTICAST_IF",
+ linux.IP_MULTICAST_LOOP: "IP_MULTICAST_LOOP",
+ linux.IP_TOS: "IP_TOS",
+ linux.IP_RECVTOS: "IP_RECVTOS",
+ linux.IPT_SO_GET_INFO: "IPT_SO_GET_INFO",
+ linux.IPT_SO_GET_ENTRIES: "IPT_SO_GET_ENTRIES",
+ linux.IP_ADD_MEMBERSHIP: "IP_ADD_MEMBERSHIP",
+ linux.IP_DROP_MEMBERSHIP: "IP_DROP_MEMBERSHIP",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.IP_ADD_SOURCE_MEMBERSHIP: "IP_ADD_SOURCE_MEMBERSHIP",
+ linux.IP_BIND_ADDRESS_NO_PORT: "IP_BIND_ADDRESS_NO_PORT",
+ linux.IP_BLOCK_SOURCE: "IP_BLOCK_SOURCE",
+ linux.IP_CHECKSUM: "IP_CHECKSUM",
+ linux.IP_DROP_SOURCE_MEMBERSHIP: "IP_DROP_SOURCE_MEMBERSHIP",
+ linux.IP_FREEBIND: "IP_FREEBIND",
+ linux.IP_HDRINCL: "IP_HDRINCL",
+ linux.IP_IPSEC_POLICY: "IP_IPSEC_POLICY",
+ linux.IP_MINTTL: "IP_MINTTL",
+ linux.IP_MSFILTER: "IP_MSFILTER",
+ linux.IP_MTU_DISCOVER: "IP_MTU_DISCOVER",
+ linux.IP_MULTICAST_ALL: "IP_MULTICAST_ALL",
+ linux.IP_NODEFRAG: "IP_NODEFRAG",
+ linux.IP_OPTIONS: "IP_OPTIONS",
+ linux.IP_PASSSEC: "IP_PASSSEC",
+ linux.IP_PKTINFO: "IP_PKTINFO",
+ linux.IP_RECVERR: "IP_RECVERR",
+ linux.IP_RECVFRAGSIZE: "IP_RECVFRAGSIZE",
+ linux.IP_RECVOPTS: "IP_RECVOPTS",
+ linux.IP_RECVORIGDSTADDR: "IP_RECVORIGDSTADDR",
+ linux.IP_RECVTTL: "IP_RECVTTL",
+ linux.IP_RETOPTS: "IP_RETOPTS",
+ linux.IP_TRANSPARENT: "IP_TRANSPARENT",
+ linux.IP_UNBLOCK_SOURCE: "IP_UNBLOCK_SOURCE",
+ linux.IP_UNICAST_IF: "IP_UNICAST_IF",
+ linux.IP_XFRM_POLICY: "IP_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT",
+ linux.IP_PKTOPTIONS: "IP_PKTOPTIONS",
+ linux.IP_MTU: "IP_MTU",
+ },
+ linux.SOL_SOCKET: {
+ linux.SO_ERROR: "SO_ERROR",
+ linux.SO_PEERCRED: "SO_PEERCRED",
+ linux.SO_PASSCRED: "SO_PASSCRED",
+ linux.SO_SNDBUF: "SO_SNDBUF",
+ linux.SO_RCVBUF: "SO_RCVBUF",
+ linux.SO_REUSEADDR: "SO_REUSEADDR",
+ linux.SO_REUSEPORT: "SO_REUSEPORT",
+ linux.SO_BINDTODEVICE: "SO_BINDTODEVICE",
+ linux.SO_BROADCAST: "SO_BROADCAST",
+ linux.SO_KEEPALIVE: "SO_KEEPALIVE",
+ linux.SO_LINGER: "SO_LINGER",
+ linux.SO_SNDTIMEO: "SO_SNDTIMEO",
+ linux.SO_RCVTIMEO: "SO_RCVTIMEO",
+ linux.SO_OOBINLINE: "SO_OOBINLINE",
+ linux.SO_TIMESTAMP: "SO_TIMESTAMP",
+ },
+ linux.SOL_TCP: {
+ linux.TCP_NODELAY: "TCP_NODELAY",
+ linux.TCP_CORK: "TCP_CORK",
+ linux.TCP_QUICKACK: "TCP_QUICKACK",
+ linux.TCP_MAXSEG: "TCP_MAXSEG",
+ linux.TCP_KEEPIDLE: "TCP_KEEPIDLE",
+ linux.TCP_KEEPINTVL: "TCP_KEEPINTVL",
+ linux.TCP_USER_TIMEOUT: "TCP_USER_TIMEOUT",
+ linux.TCP_INFO: "TCP_INFO",
+ linux.TCP_CC_INFO: "TCP_CC_INFO",
+ linux.TCP_NOTSENT_LOWAT: "TCP_NOTSENT_LOWAT",
+ linux.TCP_ZEROCOPY_RECEIVE: "TCP_ZEROCOPY_RECEIVE",
+ linux.TCP_CONGESTION: "TCP_CONGESTION",
+ linux.TCP_LINGER2: "TCP_LINGER2",
+ linux.TCP_DEFER_ACCEPT: "TCP_DEFER_ACCEPT",
+ linux.TCP_REPAIR_OPTIONS: "TCP_REPAIR_OPTIONS",
+ linux.TCP_INQ: "TCP_INQ",
+ linux.TCP_FASTOPEN: "TCP_FASTOPEN",
+ linux.TCP_FASTOPEN_CONNECT: "TCP_FASTOPEN_CONNECT",
+ linux.TCP_FASTOPEN_KEY: "TCP_FASTOPEN_KEY",
+ linux.TCP_FASTOPEN_NO_COOKIE: "TCP_FASTOPEN_NO_COOKIE",
+ linux.TCP_KEEPCNT: "TCP_KEEPCNT",
+ linux.TCP_QUEUE_SEQ: "TCP_QUEUE_SEQ",
+ linux.TCP_REPAIR: "TCP_REPAIR",
+ linux.TCP_REPAIR_QUEUE: "TCP_REPAIR_QUEUE",
+ linux.TCP_REPAIR_WINDOW: "TCP_REPAIR_WINDOW",
+ linux.TCP_SAVED_SYN: "TCP_SAVED_SYN",
+ linux.TCP_SAVE_SYN: "TCP_SAVE_SYN",
+ linux.TCP_SYNCNT: "TCP_SYNCNT",
+ linux.TCP_THIN_DUPACK: "TCP_THIN_DUPACK",
+ linux.TCP_THIN_LINEAR_TIMEOUTS: "TCP_THIN_LINEAR_TIMEOUTS",
+ linux.TCP_TIMESTAMP: "TCP_TIMESTAMP",
+ linux.TCP_ULP: "TCP_ULP",
+ linux.TCP_WINDOW_CLAMP: "TCP_WINDOW_CLAMP",
+ },
+ linux.SOL_IPV6: {
+ linux.IPV6_V6ONLY: "IPV6_V6ONLY",
+ linux.IPV6_PATHMTU: "IPV6_PATHMTU",
+ linux.IPV6_TCLASS: "IPV6_TCLASS",
+ linux.IPV6_ADD_MEMBERSHIP: "IPV6_ADD_MEMBERSHIP",
+ linux.IPV6_DROP_MEMBERSHIP: "IPV6_DROP_MEMBERSHIP",
+ linux.IPV6_IPSEC_POLICY: "IPV6_IPSEC_POLICY",
+ linux.IPV6_JOIN_ANYCAST: "IPV6_JOIN_ANYCAST",
+ linux.IPV6_LEAVE_ANYCAST: "IPV6_LEAVE_ANYCAST",
+ linux.IPV6_PKTINFO: "IPV6_PKTINFO",
+ linux.IPV6_ROUTER_ALERT: "IPV6_ROUTER_ALERT",
+ linux.IPV6_XFRM_POLICY: "IPV6_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IPV6_2292DSTOPTS: "IPV6_2292DSTOPTS",
+ linux.IPV6_2292HOPLIMIT: "IPV6_2292HOPLIMIT",
+ linux.IPV6_2292HOPOPTS: "IPV6_2292HOPOPTS",
+ linux.IPV6_2292PKTINFO: "IPV6_2292PKTINFO",
+ linux.IPV6_2292PKTOPTIONS: "IPV6_2292PKTOPTIONS",
+ linux.IPV6_2292RTHDR: "IPV6_2292RTHDR",
+ linux.IPV6_ADDR_PREFERENCES: "IPV6_ADDR_PREFERENCES",
+ linux.IPV6_AUTOFLOWLABEL: "IPV6_AUTOFLOWLABEL",
+ linux.IPV6_DONTFRAG: "IPV6_DONTFRAG",
+ linux.IPV6_DSTOPTS: "IPV6_DSTOPTS",
+ linux.IPV6_FLOWINFO: "IPV6_FLOWINFO",
+ linux.IPV6_FLOWINFO_SEND: "IPV6_FLOWINFO_SEND",
+ linux.IPV6_FLOWLABEL_MGR: "IPV6_FLOWLABEL_MGR",
+ linux.IPV6_FREEBIND: "IPV6_FREEBIND",
+ linux.IPV6_HOPOPTS: "IPV6_HOPOPTS",
+ linux.IPV6_MINHOPCOUNT: "IPV6_MINHOPCOUNT",
+ linux.IPV6_MTU: "IPV6_MTU",
+ linux.IPV6_MTU_DISCOVER: "IPV6_MTU_DISCOVER",
+ linux.IPV6_MULTICAST_ALL: "IPV6_MULTICAST_ALL",
+ linux.IPV6_MULTICAST_HOPS: "IPV6_MULTICAST_HOPS",
+ linux.IPV6_MULTICAST_IF: "IPV6_MULTICAST_IF",
+ linux.IPV6_MULTICAST_LOOP: "IPV6_MULTICAST_LOOP",
+ linux.IPV6_RECVDSTOPTS: "IPV6_RECVDSTOPTS",
+ linux.IPV6_RECVERR: "IPV6_RECVERR",
+ linux.IPV6_RECVFRAGSIZE: "IPV6_RECVFRAGSIZE",
+ linux.IPV6_RECVHOPLIMIT: "IPV6_RECVHOPLIMIT",
+ linux.IPV6_RECVHOPOPTS: "IPV6_RECVHOPOPTS",
+ linux.IPV6_RECVORIGDSTADDR: "IPV6_RECVORIGDSTADDR",
+ linux.IPV6_RECVPATHMTU: "IPV6_RECVPATHMTU",
+ linux.IPV6_RECVPKTINFO: "IPV6_RECVPKTINFO",
+ linux.IPV6_RECVRTHDR: "IPV6_RECVRTHDR",
+ linux.IPV6_RECVTCLASS: "IPV6_RECVTCLASS",
+ linux.IPV6_RTHDR: "IPV6_RTHDR",
+ linux.IPV6_RTHDRDSTOPTS: "IPV6_RTHDRDSTOPTS",
+ linux.IPV6_TRANSPARENT: "IPV6_TRANSPARENT",
+ linux.IPV6_UNICAST_HOPS: "IPV6_UNICAST_HOPS",
+ linux.IPV6_UNICAST_IF: "IPV6_UNICAST_IF",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.IPV6_ADDRFORM: "IPV6_ADDRFORM",
+ },
+ linux.SOL_NETLINK: {
+ linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR",
+ linux.NETLINK_CAP_ACK: "NETLINK_CAP_ACK",
+ linux.NETLINK_DUMP_STRICT_CHK: "NETLINK_DUMP_STRICT_CHK",
+ linux.NETLINK_EXT_ACK: "NETLINK_EXT_ACK",
+ linux.NETLINK_LIST_MEMBERSHIPS: "NETLINK_LIST_MEMBERSHIPS",
+ linux.NETLINK_NO_ENOBUFS: "NETLINK_NO_ENOBUFS",
+ linux.NETLINK_PKTINFO: "NETLINK_PKTINFO",
+ },
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 3fc4a47fc..77655558e 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -55,6 +55,14 @@ var ItimerTypes = abi.ValueSet{
linux.ITIMER_PROF: "ITIMER_PROF",
}
+func hexNum(num uint64) string {
+ return "0x" + strconv.FormatUint(num, 16)
+}
+
+func hexArg(arg arch.SyscallArgument) string {
+ return hexNum(arg.Uint64())
+}
+
func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr)
@@ -133,6 +141,10 @@ func path(t *kernel.Task, addr usermem.Addr) string {
}
func fd(t *kernel.Task, fd int32) string {
+ if kernel.VFS2Enabled {
+ return fdVFS2(t, fd)
+ }
+
root := t.FSContext().RootDirectory()
if root != nil {
defer root.DecRef()
@@ -161,6 +173,30 @@ func fd(t *kernel.Task, fd int32) string {
return fmt.Sprintf("%#x %s", fd, name)
}
+func fdVFS2(t *kernel.Task, fd int32) string {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+
+ vfsObj := root.Mount().Filesystem().VirtualFilesystem()
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, wd)
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef()
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry())
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
func fdpair(t *kernel.Task, addr usermem.Addr) string {
var fds [2]int32
_, err := t.CopyIn(addr, &fds)
@@ -389,6 +425,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, path(t, args[arg].Pointer()))
case ExecveStringVector:
output = append(output, stringVector(t, args[arg].Pointer()))
+ case SetSockOptVal:
+ output = append(output, sockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Uint64() /* optLen */, maximumBlobSize))
+ case SockOptLevel:
+ output = append(output, sockOptLevels.Parse(args[arg].Uint64()))
+ case SockOptName:
+ output = append(output, sockOptNames[args[arg-1].Uint64() /* level */].Parse(args[arg].Uint64()))
case SockAddr:
output = append(output, sockAddr(t, args[arg].Pointer(), uint32(args[arg+1].Uint64())))
case SockLen:
@@ -439,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
@@ -446,7 +494,7 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
case Hex:
fallthrough
default:
- output = append(output, "0x"+strconv.FormatUint(args[arg].Uint64(), 16))
+ output = append(output, hexArg(args[arg]))
}
}
@@ -507,6 +555,14 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
+ case GetSockOptVal:
+ output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
+ case SetSockOptVal:
+ // No need to print the value again. While it usually
+ // isn't, the string version of this arg can be long.
+ output[arg] = hexArg(args[arg])
}
}
}
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 24e29a2ba..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -207,9 +207,37 @@ const (
// array is in the next argument.
PollFDs
- // SelectFDSet is an fd_set argument in select(2)/pselect(2). The number of
- // fds represented must be the first argument.
+ // SelectFDSet is an fd_set argument in select(2)/pselect(2). The
+ // number of FDs represented must be the first argument.
SelectFDSet
+
+ // GetSockOptVal is the optval argument in getsockopt(2).
+ //
+ // Formatted after syscall execution.
+ GetSockOptVal
+
+ // SetSockOptVal is the optval argument in setsockopt(2).
+ //
+ // Contents omitted after syscall execution.
+ SetSockOptVal
+
+ // SockOptLevel is the level argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptLevel
+
+ // SockOptLevel is the optname argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index be16ee686..0d24fd3c4 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/epoll",
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 7435b50bf..79066ad2a 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -234,12 +234,12 @@ var AMD64 = &kernel.SyscallTable{
191: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
192: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
193: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
- 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 194: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs", nil),
+ 195: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs", nil),
+ 196: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs", nil),
+ 197: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs", nil),
+ 198: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs", nil),
+ 199: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs", nil),
200: syscalls.Supported("tkill", Tkill),
201: syscalls.Supported("time", Time),
202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index 03a39fe65..7421619de 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -47,12 +47,12 @@ var ARM64 = &kernel.SyscallTable{
8: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
9: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
10: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
- 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
+ 11: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs", nil),
+ 12: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs", nil),
+ 13: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs", nil),
+ 14: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs", nil),
+ 15: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs", nil),
+ 16: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs", nil),
17: syscalls.Supported("getcwd", Getcwd),
18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
19: syscalls.Supported("eventfd2", Eventfd2),
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 5f11b496c..3ab93fbde 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -83,8 +85,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
mask = waiter.EventMaskFromLinux(e.Events)
- data[0] = e.Fd
- data[1] = e.Data
+ data = e.Data
}
// Perform the requested operations.
@@ -165,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 421845ebb..d10a9bed8 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/fd.go)
+
const (
_FADV_NORMAL = 0
_FADV_RANDOM = 1
@@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1218,7 +1236,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTEMPTY
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1491,7 +1517,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTDIR
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index f66f4ffde..b126fecc0 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..3f7691eae 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 9959f6e61..91694d374 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 98db32d77..9c6728530 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -135,7 +136,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
// Set the underlying executable.
- t.MemoryManager().SetExecutable(file.Dirent)
+ t.MemoryManager().SetExecutable(fsbridge.NewFSFile(file))
case linux.PR_SET_MM_AUXV,
linux.PR_SET_MM_START_CODE,
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 227692f06..78a2cb750 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index f43d6c155..fd642834b 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -25,6 +25,10 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 {
+ return 0, syserror.EINVAL
+ }
+
var (
total int64
n int64
@@ -82,11 +86,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
offsetAddr := args[2].Pointer()
count := int64(args[3].SizeT())
- // Don't send a negative number of bytes.
- if count < 0 {
- return 0, nil, syserror.EINVAL
- }
-
// Get files.
inFile := t.GetFile(inFD)
if inFile == nil {
@@ -136,11 +135,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
- // The offset must be valid.
- if offset < 0 {
- return 0, nil, syserror.EINVAL
- }
-
// Do the splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
@@ -227,6 +221,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if _, err := t.CopyIn(outOffset, &offset); err != nil {
return 0, nil, err
}
+
// Use the destination offset.
opts.DstOffset = true
opts.DstStart = offset
@@ -244,6 +239,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if _, err := t.CopyIn(inOffset, &offset); err != nil {
return 0, nil, err
}
+
// Use the source offset.
opts.SrcOffset = true
opts.SrcStart = offset
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index c841abccb..9bd2df104 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Stat implements linux syscall stat(2).
func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -112,7 +114,8 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
if err != nil {
return err
}
- return copyOutStat(t, statAddr, d.Inode.StableAttr, uattr)
+ s := statFromAttrs(t, d.Inode.StableAttr, uattr)
+ return s.CopyOut(t, statAddr)
}
// fstat implements fstat for the given *fs.File.
@@ -121,7 +124,8 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
if err != nil {
return err
}
- return copyOutStat(t, statAddr, f.Dirent.Inode.StableAttr, uattr)
+ s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
+ return s.CopyOut(t, statAddr)
}
// Statx implements linux syscall statx(2).
@@ -277,3 +281,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
index 75a567bd4..0a04a6113 100644
--- a/pkg/sentry/syscalls/linux/sys_stat_amd64.go
+++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
@@ -12,64 +12,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//+build amd64
+// +build amd64
package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
)
-// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
-// address dst in t's address space. It encodes the stat struct to bytes
-// manually, as stat() is a very common syscall for many applications, and
-// t.CopyObjectOut has noticeable performance impact due to its many slice
-// allocations and use of reflection.
-func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
- b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
-
- // Dev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
- // Ino (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
- // Nlink (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uattr.Links)
- // Mode (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, sattr.Type.LinuxType()|uint32(uattr.Perms.LinuxMode()))
- // UID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- // GID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
- // Padding (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, 0)
- // Rdev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
- // Size (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
- // Blksize (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.BlockSize))
- // Blocks (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
-
- // ATime
- atime := uattr.AccessTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
-
- // MTime
- mtime := uattr.ModificationTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
-
- // CTime
- ctime := uattr.StatusChangeTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
-
- _, err := t.CopyOutBytes(dst, b)
- return err
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
}
+
+// LINT.ThenChange(vfs2/stat_amd64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
index 80c98d05c..5a3b1bfad 100644
--- a/pkg/sentry/syscalls/linux/sys_stat_arm64.go
+++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
@@ -12,66 +12,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//+build arm64
+// +build arm64
package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/usermem"
)
-// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
-// address dst in t's address space. It encodes the stat struct to bytes
-// manually, as stat() is a very common syscall for many applications, and
-// t.CopyObjectOut has noticeable performance impact due to its many slice
-// allocations and use of reflection.
-func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
- b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
-
- // Dev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
- // Ino (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
- // Mode (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, sattr.Type.LinuxType()|uint32(uattr.Perms.LinuxMode()))
- // Nlink (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Links))
- // UID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- // GID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
- // Rdev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
- // Padding (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, 0)
- // Size (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
- // Blksize (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(sattr.BlockSize))
- // Padding (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, 0)
- // Blocks (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
-
- // ATime
- atime := uattr.AccessTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
-
- // MTime
- mtime := uattr.ModificationTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
-
- // CTime
- ctime := uattr.StatusChangeTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
-
- _, err := t.CopyOutBytes(dst, b)
- return err
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uint32(uattr.Links),
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: int32(sattr.BlockSize),
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
}
+
+// LINT.ThenChange(vfs2/stat_arm64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..5ad465ae3 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 0c9e2255d..00915fdde 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/loader"
@@ -119,7 +120,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
defer root.DecRef()
var wd *fs.Dirent
- var executable *fs.File
+ var executable fsbridge.File
var closeOnExec bool
if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
// Even if the pathname is absolute, we may still need the wd
@@ -136,7 +137,15 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
closeOnExec = fdFlags.CloseOnExec
if atEmptyPath && len(pathname) == 0 {
- executable = f
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ if err := f.Dirent.Inode.CheckPermission(t, fs.PermMask{Read: true, Execute: true}); err != nil {
+ return 0, nil, err
+ }
+ executable = fsbridge.NewFSFile(f)
} else {
wd = f.Dirent
wd.IncRef()
@@ -152,9 +161,7 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
// Load the new TaskContext.
remainingTraversals := uint(linux.MaxSymlinkTraversals)
loadArgs := loader.LoadArgs{
- Mounts: t.MountNamespace(),
- Root: root,
- WorkingDirectory: wd,
+ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd),
RemainingTraversals: &remainingTraversals,
ResolveFinal: resolveFinal,
Filename: pathname,
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
index 432351917..a4c400f87 100644
--- a/pkg/sentry/syscalls/linux/sys_timer.go
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -146,7 +146,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- return uintptr(id), nil, nil
+ return 0, nil, nil
}
// TimerSettime implements linux syscall timer_settime(2).
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index aba892939..506ee54ce 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index efb95555c..2de5e3422 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// GetXattr implements linux syscall getxattr(2).
func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getXattrFromPath(t, args, true)
@@ -49,14 +51,11 @@ func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
defer f.DecRef()
- n, value, err := getXattr(t, f.Dirent, nameAddr, size)
+ n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size)
if err != nil {
return 0, nil, err
}
- if _, err := t.CopyOutBytes(valueAddr, []byte(value)); err != nil {
- return 0, nil, err
- }
return uintptr(n), nil, nil
}
@@ -71,41 +70,36 @@ func getXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink
return 0, nil, err
}
- valueLen := 0
- err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return syserror.ENOTDIR
}
- n, value, err := getXattr(t, d, nameAddr, size)
- valueLen = n
- if err != nil {
- return err
- }
-
- _, err = t.CopyOutBytes(valueAddr, []byte(value))
+ n, err = getXattr(t, d, nameAddr, valueAddr, size)
return err
})
if err != nil {
return 0, nil, err
}
- return uintptr(valueLen), nil, nil
+
+ return uintptr(n), nil, nil
}
// getXattr implements getxattr(2) from the given *fs.Dirent.
-func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr, size uint64) (int, string, error) {
- if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil {
- return 0, "", err
- }
-
+func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64) (int, error) {
name, err := copyInXattrName(t, nameAddr)
if err != nil {
- return 0, "", err
+ return 0, err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil {
+ return 0, err
}
// TODO(b/148380782): Support xattrs in namespaces other than "user".
if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
- return 0, "", syserror.EOPNOTSUPP
+ return 0, syserror.EOPNOTSUPP
}
// If getxattr(2) is called with size 0, the size of the value will be
@@ -118,18 +112,22 @@ func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr, size uint64)
value, err := d.Inode.GetXattr(t, name, requestedSize)
if err != nil {
- return 0, "", err
+ return 0, err
}
n := len(value)
if uint64(n) > requestedSize {
- return 0, "", syserror.ERANGE
+ return 0, syserror.ERANGE
}
// Don't copy out the attribute value if size is 0.
if size == 0 {
- return n, "", nil
+ return n, nil
}
- return n, value, nil
+
+ if _, err = t.CopyOutBytes(valueAddr, []byte(value)); err != nil {
+ return 0, err
+ }
+ return n, nil
}
// SetXattr implements linux syscall setxattr(2).
@@ -172,7 +170,7 @@ func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink
return 0, nil, err
}
- return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return syserror.ENOTDIR
}
@@ -187,12 +185,12 @@ func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, si
return syserror.EINVAL
}
- if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
return err
}
- name, err := copyInXattrName(t, nameAddr)
- if err != nil {
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
return err
}
@@ -226,12 +224,18 @@ func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
return name, nil
}
+// Restrict xattrs to regular files and directories.
+//
+// TODO(b/148380782): In Linux, this restriction technically only applies to
+// xattrs in the "user.*" namespace. Make file type checks specific to the
+// namespace once we allow other xattr prefixes.
+func xattrFileTypeOk(i *fs.Inode) bool {
+ return fs.IsRegular(i.StableAttr) || fs.IsDir(i.StableAttr)
+}
+
func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error {
// Restrict xattrs to regular files and directories.
- //
- // In Linux, this restriction technically only applies to xattrs in the
- // "user.*" namespace, but we don't allow any other xattr prefixes anyway.
- if !fs.IsRegular(i.StableAttr) && !fs.IsDir(i.StableAttr) {
+ if !xattrFileTypeOk(i) {
if perms.Write {
return syserror.EPERM
}
@@ -240,3 +244,181 @@ func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error
return i.CheckPermission(t, perms)
}
+
+// ListXattr implements linux syscall listxattr(2).
+func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, true)
+}
+
+// LListXattr implements linux syscall llistxattr(2).
+func LListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, false)
+}
+
+// FListXattr implements linux syscall flistxattr(2).
+func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ n, err := listXattr(t, f.Dirent, listAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ n, err = listXattr(t, d, listAddr, size)
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattr(t *kernel.Task, d *fs.Dirent, addr usermem.Addr, size uint64) (int, error) {
+ if !xattrFileTypeOk(d.Inode) {
+ return 0, nil
+ }
+
+ // If listxattr(2) is called with size 0, the buffer size needed to contain
+ // the xattr list will be returned successfully even if it is nonzero. In
+ // that case, we need to retrieve the entire list so we can compute and
+ // return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+ xattrs, err := d.Inode.ListXattr(t, requestedSize)
+ if err != nil {
+ return 0, err
+ }
+
+ // TODO(b/148380782): support namespaces other than "user".
+ for x := range xattrs {
+ if !strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ delete(xattrs, x)
+ }
+ }
+
+ listSize := xattrListSize(xattrs)
+ if listSize > linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ if uint64(listSize) > requestedSize {
+ return 0, syserror.ERANGE
+ }
+
+ // Don't copy out the attributes if size is 0.
+ if size == 0 {
+ return listSize, nil
+ }
+
+ buf := make([]byte, 0, listSize)
+ for x := range xattrs {
+ buf = append(buf, []byte(x)...)
+ buf = append(buf, 0)
+ }
+ if _, err := t.CopyOutBytes(addr, buf); err != nil {
+ return 0, err
+ }
+
+ return len(buf), nil
+}
+
+func xattrListSize(xattrs map[string]struct{}) int {
+ size := 0
+ for x := range xattrs {
+ size += len(x) + 1
+ }
+ return size
+}
+
+// RemoveXattr implements linux syscall removexattr(2).
+func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, true)
+}
+
+// LRemoveXattr implements linux syscall lremovexattr(2).
+func LRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, false)
+}
+
+// FRemoveXattr implements linux syscall fremovexattr(2).
+func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef()
+
+ return 0, nil, removeXattr(t, f.Dirent, nameAddr)
+}
+
+func removeXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ return removeXattr(t, d, nameAddr)
+ })
+}
+
+// removeXattr implements removexattr(2) from the given *fs.Dirent.
+func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ return d.Inode.RemoveXattr(t, d, name)
+}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6b8a00b6e..e7695e995 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,18 +5,46 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "epoll.go",
+ "epoll_unsafe.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "ioctl.go",
"linux64.go",
"linux64_override_amd64.go",
"linux64_override_arm64.go",
- "sys_read.go",
+ "mmap.go",
+ "path.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "stat.go",
+ "stat_amd64.go",
+ "stat_arm64.go",
+ "sync.go",
+ "xattr.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..d6cb0e79a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
new file mode 100644
index 000000000..825f325bf
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
+
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
+ if len(events) == 0 {
+ return 0, nil
+ }
+ // Cast events to a byte slice for copying.
+ var eventBytes []byte
+ eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
+ eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
+ eventBytesHdr.Len = len(events) * sizeofEpollEvent
+ eventBytesHdr.Cap = len(events) * sizeofEpollEvent
+ copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
+ runtime.KeepAlive(events)
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ return copiedEvents, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..3afcea665
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, nil
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ default:
+ // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..fc5ceea4c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..ddc140b65
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ buf[size-1] = 0 // NUL terminator
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ buf[size-3] = 0 // NUL terminator
+ buf[size-2] = 0 // zero padding byte
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..5a2418da9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index c134714ee..7d220bc20 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package vfs2
import (
@@ -22,4 +24,142 @@ import (
// Override syscall table to add syscalls implementations from this package.
func Override(table map[uintptr]kernel.Syscall) {
table[0] = syscalls.Supported("read", Read)
+ table[1] = syscalls.Supported("write", Write)
+ table[2] = syscalls.Supported("open", Open)
+ table[3] = syscalls.Supported("close", Close)
+ table[4] = syscalls.Supported("stat", Stat)
+ table[5] = syscalls.Supported("fstat", Fstat)
+ table[6] = syscalls.Supported("lstat", Lstat)
+ table[7] = syscalls.Supported("poll", Poll)
+ table[8] = syscalls.Supported("lseek", Lseek)
+ table[9] = syscalls.Supported("mmap", Mmap)
+ table[16] = syscalls.Supported("ioctl", Ioctl)
+ table[17] = syscalls.Supported("pread64", Pread64)
+ table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ table[19] = syscalls.Supported("readv", Readv)
+ table[20] = syscalls.Supported("writev", Writev)
+ table[21] = syscalls.Supported("access", Access)
+ delete(table, 22) // pipe
+ table[23] = syscalls.Supported("select", Select)
+ table[32] = syscalls.Supported("dup", Dup)
+ table[33] = syscalls.Supported("dup2", Dup2)
+ delete(table, 40) // sendfile
+ delete(table, 41) // socket
+ delete(table, 42) // connect
+ delete(table, 43) // accept
+ delete(table, 44) // sendto
+ delete(table, 45) // recvfrom
+ delete(table, 46) // sendmsg
+ delete(table, 47) // recvmsg
+ delete(table, 48) // shutdown
+ delete(table, 49) // bind
+ delete(table, 50) // listen
+ delete(table, 51) // getsockname
+ delete(table, 52) // getpeername
+ delete(table, 53) // socketpair
+ delete(table, 54) // setsockopt
+ delete(table, 55) // getsockopt
+ table[59] = syscalls.Supported("execve", Execve)
+ table[72] = syscalls.Supported("fcntl", Fcntl)
+ delete(table, 73) // flock
+ table[74] = syscalls.Supported("fsync", Fsync)
+ table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ table[76] = syscalls.Supported("truncate", Truncate)
+ table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ table[78] = syscalls.Supported("getdents", Getdents)
+ table[79] = syscalls.Supported("getcwd", Getcwd)
+ table[80] = syscalls.Supported("chdir", Chdir)
+ table[81] = syscalls.Supported("fchdir", Fchdir)
+ table[82] = syscalls.Supported("rename", Rename)
+ table[83] = syscalls.Supported("mkdir", Mkdir)
+ table[84] = syscalls.Supported("rmdir", Rmdir)
+ table[85] = syscalls.Supported("creat", Creat)
+ table[86] = syscalls.Supported("link", Link)
+ table[87] = syscalls.Supported("unlink", Unlink)
+ table[88] = syscalls.Supported("symlink", Symlink)
+ table[89] = syscalls.Supported("readlink", Readlink)
+ table[90] = syscalls.Supported("chmod", Chmod)
+ table[91] = syscalls.Supported("fchmod", Fchmod)
+ table[92] = syscalls.Supported("chown", Chown)
+ table[93] = syscalls.Supported("fchown", Fchown)
+ table[94] = syscalls.Supported("lchown", Lchown)
+ table[132] = syscalls.Supported("utime", Utime)
+ table[133] = syscalls.Supported("mknod", Mknod)
+ table[137] = syscalls.Supported("statfs", Statfs)
+ table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ table[161] = syscalls.Supported("chroot", Chroot)
+ table[162] = syscalls.Supported("sync", Sync)
+ delete(table, 165) // mount
+ delete(table, 166) // umount2
+ delete(table, 187) // readahead
+ table[188] = syscalls.Supported("setxattr", Setxattr)
+ table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ table[191] = syscalls.Supported("getxattr", Getxattr)
+ table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ table[194] = syscalls.Supported("listxattr", Listxattr)
+ table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ table[197] = syscalls.Supported("removexattr", Removexattr)
+ table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
+ delete(table, 206) // io_setup
+ delete(table, 207) // io_destroy
+ delete(table, 208) // io_getevents
+ delete(table, 209) // io_submit
+ delete(table, 210) // io_cancel
+ table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ table[217] = syscalls.Supported("getdents64", Getdents64)
+ delete(table, 221) // fdavise64
+ table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ table[235] = syscalls.Supported("utimes", Utimes)
+ delete(table, 253) // inotify_init
+ delete(table, 254) // inotify_add_watch
+ delete(table, 255) // inotify_rm_watch
+ table[257] = syscalls.Supported("openat", Openat)
+ table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ table[259] = syscalls.Supported("mknodat", Mknodat)
+ table[260] = syscalls.Supported("fchownat", Fchownat)
+ table[261] = syscalls.Supported("futimens", Futimens)
+ table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ table[264] = syscalls.Supported("renameat", Renameat)
+ table[265] = syscalls.Supported("linkat", Linkat)
+ table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ table[269] = syscalls.Supported("faccessat", Faccessat)
+ table[270] = syscalls.Supported("pselect", Pselect)
+ table[271] = syscalls.Supported("ppoll", Ppoll)
+ delete(table, 275) // splice
+ delete(table, 276) // tee
+ table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ table[280] = syscalls.Supported("utimensat", Utimensat)
+ table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ delete(table, 282) // signalfd
+ delete(table, 283) // timerfd_create
+ delete(table, 284) // eventfd
+ delete(table, 285) // fallocate
+ delete(table, 286) // timerfd_settime
+ delete(table, 287) // timerfd_gettime
+ delete(table, 288) // accept4
+ delete(table, 289) // signalfd4
+ delete(table, 290) // eventfd2
+ table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ table[292] = syscalls.Supported("dup3", Dup3)
+ delete(table, 293) // pipe2
+ delete(table, 294) // inotify_init1
+ table[295] = syscalls.Supported("preadv", Preadv)
+ table[296] = syscalls.Supported("pwritev", Pwritev)
+ delete(table, 299) // recvmmsg
+ table[306] = syscalls.Supported("syncfs", Syncfs)
+ delete(table, 307) // sendmmsg
+ table[316] = syscalls.Supported("renameat2", Renameat2)
+ delete(table, 319) // memfd_create
+ table[322] = syscalls.Supported("execveat", Execveat)
+ table[327] = syscalls.Supported("preadv2", Preadv2)
+ table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ table[332] = syscalls.Supported("statx", Statx)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
index 6af5c400f..a6b367468 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package vfs2
import (
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..dbf4882da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,584 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return tsRemaining.CopyOut(t, timespecAddr)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return tvRemaining.CopyOut(t, timevalAddr)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..35f6308d6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,511 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..9250659ff
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,380 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimens implements Linux syscall futimens(2).
+func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ timesAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.SetStat(t, opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..12c532310
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,323 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return 0, nil, stat.CopyOut(t, statAddr)
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Faccessat implements Linux syscall access(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
new file mode 100644
index 000000000..2da538fc6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
new file mode 100644
index 000000000..88b9c7627
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint32(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int32(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..365250b0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
+ // [offset, offset+nbytes).
+ return 0, nil, file.Sync(t)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go
deleted file mode 100644
index 7667524c7..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // EventMaskRead contains events that can be triggered on reads.
- EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
-)
-
-// Read implements linux syscall read(2). Note that we try to get a buffer that
-// is exactly the size requested because some applications like qemu expect
-// they can do large reads all at once. Bug for bug. Same for other read
-// calls below.
-func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- addr := args[1].Pointer()
- size := args[2].SizeT()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
- }
- defer file.DecRef()
-
- // Check that the size is legitimate.
- si := int(size)
- if si < 0 {
- return 0, nil, syserror.EINVAL
- }
-
- // Get the destination of the read.
- dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- if err != nil {
- return 0, nil, err
- }
-
- n, err := read(t, file, dst, vfs.ReadOptions{})
- t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
-}
-
-func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
- return n, err
- }
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- file.EventRegister(&w, EventMaskRead)
-
- total := n
- for {
- // Shorten dst to reflect bytes previously read.
- dst = dst.DropFirst(int(n))
-
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err := file.Read(t, dst, opts)
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
- if err := t.Block(ch); err != nil {
- break
- }
- }
- file.EventUnregister(&w)
-
- return total, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..89e9ff4d7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index 538c645eb..4320ad17f 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -253,6 +253,10 @@ func (m *MemoryLocked) Copy() (MemoryStats, uint64) {
}
// MinimumTotalMemoryBytes is the minimum reported total system memory.
+//
+// This can be configured through options provided to the Sentry at start.
+// This number is purely synthetic. This is only set before the application
+// starts executing, and must not be modified.
var MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
// TotalMemory returns the "total usable memory" available.
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index ced9d07b1..07c8383e6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -43,7 +43,10 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/log",
"//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sync",
diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go
index d97362b9a..82781e6d3 100644
--- a/pkg/sentry/vfs/context.go
+++ b/pkg/sentry/vfs/context.go
@@ -29,9 +29,10 @@ const (
CtxRoot
)
-// MountNamespaceFromContext returns the MountNamespace used by ctx. It does
-// not take a reference on the returned MountNamespace. If ctx is not
-// associated with a MountNamespace, MountNamespaceFromContext returns nil.
+// MountNamespaceFromContext returns the MountNamespace used by ctx. If ctx is
+// not associated with a MountNamespace, MountNamespaceFromContext returns nil.
+//
+// A reference is taken on the returned MountNamespace.
func MountNamespaceFromContext(ctx context.Context) *MountNamespace {
if v := ctx.Value(CtxMountNamespace); v != nil {
return v.(*MountNamespace)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 486a76475..35b208721 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -71,6 +71,8 @@ import (
// lifetime. Dentry reference counts only indicate the extent to which VFS
// requires Dentries to exist; Filesystems may elect to cache or discard
// Dentries with zero references.
+//
+// +stateify savable
type Dentry struct {
// parent is this Dentry's parent in this Filesystem. If this Dentry is
// independent, parent is nil.
@@ -89,7 +91,7 @@ type Dentry struct {
children map[string]*Dentry
// mu synchronizes disowning and mounting over this Dentry.
- mu sync.Mutex
+ mu sync.Mutex `state:"nosave"`
// impl is the DentryImpl associated with this Dentry. impl is immutable.
// This should be the last field in Dentry.
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
index 3af2aa58d..bda5576fa 100644
--- a/pkg/sentry/vfs/device.go
+++ b/pkg/sentry/vfs/device.go
@@ -56,6 +56,7 @@ type Device interface {
Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
}
+// +stateify savable
type registeredDevice struct {
dev Device
opts RegisterDeviceOptions
@@ -63,6 +64,8 @@ type registeredDevice struct {
// RegisterDeviceOptions contains options to
// VirtualFilesystem.RegisterDevice().
+//
+// +stateify savable
type RegisterDeviceOptions struct {
// GroupName is the name shown for this device registration in
// /proc/devices. If GroupName is empty, this registration will not be
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index 7c83f9a5a..3da45d744 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -85,8 +85,8 @@ type epollInterest struct {
ready bool
epollInterestEntry
- // userData is the epoll_data_t associated with this epollInterest.
- // userData is protected by epoll.mu.
+ // userData is the struct epoll_event::data associated with this
+ // epollInterest. userData is protected by epoll.mu.
userData [2]int32
}
@@ -157,7 +157,7 @@ func (ep *EpollInstance) Seek(ctx context.Context, offset int64, whence int32) (
// AddInterest implements the semantics of EPOLL_CTL_ADD.
//
// Preconditions: A reference must be held on file.
-func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint32, userData [2]int32) error {
+func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
// Check for cyclic polling if necessary.
subep, _ := file.impl.(*EpollInstance)
if subep != nil {
@@ -183,12 +183,12 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint
}
// Register interest in file.
- mask |= linux.EPOLLERR | linux.EPOLLRDHUP
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
epi := &epollInterest{
epoll: ep,
key: key,
mask: mask,
- userData: userData,
+ userData: event.Data,
}
ep.interest[key] = epi
wmask := waiter.EventMaskFromLinux(mask)
@@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, mask uint
// Add epi to file.epolls so that it is removed when the last
// FileDescription reference is dropped.
file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
file.epolls[epi] = struct{}{}
file.epollMu.Unlock()
@@ -236,7 +239,7 @@ func (ep *EpollInstance) mightPollRecursive(ep2 *EpollInstance, remainingRecursi
// ModifyInterest implements the semantics of EPOLL_CTL_MOD.
//
// Preconditions: A reference must be held on file.
-func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, mask uint32, userData [2]int32) error {
+func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
ep.interestMu.Lock()
defer ep.interestMu.Unlock()
@@ -250,13 +253,13 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, mask u
}
// Update epi for the next call to ep.ReadEvents().
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP
ep.mu.Lock()
epi.mask = mask
- epi.userData = userData
+ epi.userData = event.Data
ep.mu.Unlock()
// Re-register with the new mask.
- mask |= linux.EPOLLERR | linux.EPOLLRDHUP
file.EventUnregister(&epi.waiter)
wmask := waiter.EventMaskFromLinux(mask)
file.EventRegister(&epi.waiter, wmask)
@@ -363,8 +366,7 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
// Report ievents.
events[i] = linux.EpollEvent{
Events: ievents.ToLinux(),
- Fd: epi.userData[0],
- Data: epi.userData[1],
+ Data: epi.userData,
}
i++
if i == len(events) {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index badacb55e..9a1ad630c 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
@@ -393,7 +394,25 @@ type FileDescriptionImpl interface {
// Removexattr removes the given extended attribute from the file.
Removexattr(ctx context.Context, name string) error
- // TODO: file locking
+ // LockBSD tries to acquire a BSD-style advisory file lock.
+ //
+ // TODO(gvisor.dev/issue/1480): BSD-style file locking
+ LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
+
+ // LockBSD releases a BSD-style advisory file lock.
+ //
+ // TODO(gvisor.dev/issue/1480): BSD-style file locking
+ UnlockBSD(ctx context.Context, uid lock.UniqueID) error
+
+ // LockPOSIX tries to acquire a POSIX-style advisory file lock.
+ //
+ // TODO(gvisor.dev/issue/1480): POSIX-style file locking
+ LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error
+
+ // UnlockPOSIX releases a POSIX-style advisory file lock.
+ //
+ // TODO(gvisor.dev/issue/1480): POSIX-style file locking
+ UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -416,11 +435,11 @@ type Dirent struct {
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
type IterDirentsCallback interface {
- // Handle handles the given iterated Dirent. It returns true if iteration
- // should continue, and false if FileDescriptionImpl.IterDirents should
- // terminate now and restart with the same Dirent the next time it is
- // called.
- Handle(dirent Dirent) bool
+ // Handle handles the given iterated Dirent. If Handle returns a non-nil
+ // error, FileDescriptionImpl.IterDirents must stop iteration and return
+ // the error; the next call to FileDescriptionImpl.IterDirents should
+ // restart with the same Dirent.
+ Handle(dirent Dirent) error
}
// OnClose is called when a file descriptor representing the FileDescription is
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index a4900c170..c2a52ec1b 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -152,6 +153,26 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string)
return syserror.ENOTSUP
}
+// LockBSD implements FileDescriptionImpl.LockBSD.
+func (FileDescriptionDefaultImpl) LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error {
+ return syserror.EBADF
+}
+
+// UnlockBSD implements FileDescriptionImpl.UnlockBSD.
+func (FileDescriptionDefaultImpl) UnlockBSD(ctx context.Context, uid lock.UniqueID) error {
+ return syserror.EBADF
+}
+
+// LockPOSIX implements FileDescriptionImpl.LockPOSIX.
+func (FileDescriptionDefaultImpl) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, rng lock.LockRange, block lock.Blocker) error {
+ return syserror.EBADF
+}
+
+// UnlockPOSIX implements FileDescriptionImpl.UnlockPOSIX.
+func (FileDescriptionDefaultImpl) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng lock.LockRange) error {
+ return syserror.EBADF
+}
+
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 8fa26418e..3a75d4d62 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -107,7 +107,10 @@ func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error {
func TestGenCountFD(t *testing.T) {
ctx := contexttest.Context(t)
- vfsObj := New() // vfs.New()
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
fd := newTestFD(vfsObj, linux.O_RDWR, &genCount{})
defer fd.DecRef()
@@ -162,7 +165,10 @@ func TestGenCountFD(t *testing.T) {
func TestWritable(t *testing.T) {
ctx := contexttest.Context(t)
- vfsObj := New() // vfs.New()
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
fd := newTestFD(vfsObj, linux.O_RDWR, &storeData{data: "init"})
defer fd.DecRef()
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index a06a6caf3..556976d0b 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -29,6 +29,8 @@ import (
// Filesystem methods require that a reference is held.
//
// Filesystem is analogous to Linux's struct super_block.
+//
+// +stateify savable
type Filesystem struct {
// refs is the reference count. refs is accessed using atomic memory
// operations.
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index c58b70728..bb9cada81 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -44,6 +44,7 @@ type GetFilesystemOptions struct {
InternalData interface{}
}
+// +stateify savable
type registeredFilesystemType struct {
fsType FilesystemType
opts RegisterFilesystemTypeOptions
diff --git a/pkg/sentry/vfs/lock/BUILD b/pkg/sentry/vfs/lock/BUILD
new file mode 100644
index 000000000..d9ab063b7
--- /dev/null
+++ b/pkg/sentry/vfs/lock/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "lock",
+ srcs = ["lock.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/sentry/fs/lock",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/vfs/lock/lock.go b/pkg/sentry/vfs/lock/lock.go
new file mode 100644
index 000000000..724dfe743
--- /dev/null
+++ b/pkg/sentry/vfs/lock/lock.go
@@ -0,0 +1,72 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lock provides POSIX and BSD style file locking for VFS2 file
+// implementations.
+//
+// The actual implementations can be found in the lock package under
+// sentry/fs/lock.
+package lock
+
+import (
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2)
+// and flock(2) respectively in Linux. It can be embedded into various file
+// implementations for VFS2 that support locking.
+//
+// Note that in Linux these two types of locks are _not_ cooperative, because
+// race and deadlock conditions make merging them prohibitive. We do the same
+// and keep them oblivious to each other.
+type FileLocks struct {
+ // bsd is a set of BSD-style advisory file wide locks, see flock(2).
+ bsd fslock.Locks
+
+ // posix is a set of POSIX-style regional advisory locks, see fcntl(2).
+ posix fslock.Locks
+}
+
+// LockBSD tries to acquire a BSD-style lock on the entire file.
+func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockBSD releases a BSD-style lock on the entire file.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) {
+ fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF})
+}
+
+// LockPOSIX tries to acquire a POSIX-style lock on a file region.
+func (fl *FileLocks) LockPOSIX(uid fslock.UniqueID, t fslock.LockType, rng fslock.LockRange, block fslock.Blocker) error {
+ if fl.posix.LockRegion(uid, t, rng, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockPOSIX releases a POSIX-style lock on a file region.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockPOSIX(uid fslock.UniqueID, rng fslock.LockRange) {
+ fl.posix.UnlockRegion(uid, rng)
+}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index d39528051..31a4e5480 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -38,6 +38,8 @@ import (
//
// Mount is analogous to Linux's struct mount. (gVisor does not distinguish
// between struct mount and struct vfsmount.)
+//
+// +stateify savable
type Mount struct {
// vfs, fs, and root are immutable. References are held on fs and root.
//
@@ -85,6 +87,8 @@ type Mount struct {
// MountNamespace methods require that a reference is held.
//
// MountNamespace is analogous to Linux's struct mnt_namespace.
+//
+// +stateify savable
type MountNamespace struct {
// root is the MountNamespace's root mount. root is immutable.
root *Mount
@@ -114,6 +118,7 @@ type MountNamespace struct {
func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
rft := vfs.getFilesystemType(fsTypeName)
if rft == nil {
+ ctx.Warningf("Unknown filesystem: %s", fsTypeName)
return nil, syserror.ENODEV
}
fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
@@ -134,6 +139,23 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return mntns, nil
}
+// NewDisconnectedMount returns a Mount representing fs with the given root
+// (which may be nil). The new Mount is not associated with any MountNamespace
+// and is not connected to any other Mounts. References are taken on fs and
+// root.
+func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) {
+ fs.IncRef()
+ if root != nil {
+ root.IncRef()
+ }
+ return &Mount{
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ refs: 1,
+ }, nil
+}
+
// MountAt creates and mounts a Filesystem configured by the given arguments.
func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
rft := vfs.getFilesystemType(fsTypeName)
@@ -231,9 +253,12 @@ func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credenti
return syserror.EINVAL
}
vfs.mountMu.Lock()
- if mntns := MountNamespaceFromContext(ctx); mntns != nil && mntns != vd.mount.ns {
- vfs.mountMu.Unlock()
- return syserror.EINVAL
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil {
+ defer mntns.DecRef()
+ if mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
}
// TODO(jamieliu): Linux special-cases umount of the caller's root, which
@@ -423,7 +448,8 @@ func (mntns *MountNamespace) IncRef() {
}
// DecRef decrements mntns' reference count.
-func (mntns *MountNamespace) DecRef(vfs *VirtualFilesystem) {
+func (mntns *MountNamespace) DecRef() {
+ vfs := mntns.root.fs.VirtualFilesystem()
if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 {
vfs.mountMu.Lock()
vfs.mounts.seq.BeginWrite()
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index bd90d36c4..bc7581698 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,6 +26,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -64,6 +65,8 @@ func (mnt *Mount) storeKey(vd VirtualDentry) {
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
+//
+// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -75,8 +78,8 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
+ seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -89,7 +92,7 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- slots unsafe.Pointer // []mountSlot; never nil after Init
+ slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
type mountSlot struct {
@@ -158,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -359,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index b7774bf28..6af7fdac1 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -61,7 +61,7 @@ type MountOptions struct {
type OpenOptions struct {
// Flags contains access mode and flags as specified for open(2).
//
- // FilesystemImpls is reponsible for implementing the following flags:
+ // FilesystemImpls are responsible for implementing the following flags:
// O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC,
// O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and
// O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and
@@ -72,6 +72,11 @@ type OpenOptions struct {
// If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the
// created file.
Mode linux.FileMode
+
+ // FileExec is set when the file is being opened to be executed.
+ // VirtualFilesystem.OpenAt() checks that the caller has execute permissions
+ // on the file, and that the file is a regular file.
+ FileExec bool
}
// ReadOptions contains options to FileDescription.PRead(),
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f664581f4..8e250998a 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -103,17 +103,22 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// AccessTypesForOpenFlags returns MayRead|MayWrite in this case.
//
// Use May{Read,Write}FileWithOpenFlags() for these checks instead.
-func AccessTypesForOpenFlags(flags uint32) AccessTypes {
- switch flags & linux.O_ACCMODE {
+func AccessTypesForOpenFlags(opts *OpenOptions) AccessTypes {
+ ats := AccessTypes(0)
+ if opts.FileExec {
+ ats |= MayExec
+ }
+
+ switch opts.Flags & linux.O_ACCMODE {
case linux.O_RDONLY:
- if flags&linux.O_TRUNC != 0 {
- return MayRead | MayWrite
+ if opts.Flags&linux.O_TRUNC != 0 {
+ return ats | MayRead | MayWrite
}
- return MayRead
+ return ats | MayRead
case linux.O_WRONLY:
- return MayWrite
+ return ats | MayWrite
default:
- return MayRead | MayWrite
+ return ats | MayRead | MayWrite
}
}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8a0b382f6..eb4ebb511 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 908c69f91..bde81e1ef 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -46,11 +46,13 @@ import (
//
// There is no analogue to the VirtualFilesystem type in Linux, as the
// equivalent state in Linux is global.
+//
+// +stateify savable
type VirtualFilesystem struct {
// mountMu serializes mount mutations.
//
// mountMu is analogous to Linux's namespace_sem.
- mountMu sync.Mutex
+ mountMu sync.Mutex `state:"nosave"`
// mounts maps (mount parent, mount point) pairs to mounts. (Since mounts
// are uniquely namespaced, including mount parent in the key correctly
@@ -89,56 +91,59 @@ type VirtualFilesystem struct {
// devices contains all registered Devices. devices is protected by
// devicesMu.
- devicesMu sync.RWMutex
+ devicesMu sync.RWMutex `state:"nosave"`
devices map[devTuple]*registeredDevice
// anonBlockDevMinor contains all allocated anonymous block device minor
// numbers. anonBlockDevMinorNext is a lower bound for the smallest
// unallocated anonymous block device number. anonBlockDevMinorNext and
// anonBlockDevMinor are protected by anonBlockDevMinorMu.
- anonBlockDevMinorMu sync.Mutex
+ anonBlockDevMinorMu sync.Mutex `state:"nosave"`
anonBlockDevMinorNext uint32
anonBlockDevMinor map[uint32]struct{}
// fsTypes contains all registered FilesystemTypes. fsTypes is protected by
// fsTypesMu.
- fsTypesMu sync.RWMutex
+ fsTypesMu sync.RWMutex `state:"nosave"`
fsTypes map[string]*registeredFilesystemType
// filesystems contains all Filesystems. filesystems is protected by
// filesystemsMu.
- filesystemsMu sync.Mutex
+ filesystemsMu sync.Mutex `state:"nosave"`
filesystems map[*Filesystem]struct{}
}
-// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
-func New() *VirtualFilesystem {
- vfs := &VirtualFilesystem{
- mountpoints: make(map[*Dentry]map[*Mount]struct{}),
- devices: make(map[devTuple]*registeredDevice),
- anonBlockDevMinorNext: 1,
- anonBlockDevMinor: make(map[uint32]struct{}),
- fsTypes: make(map[string]*registeredFilesystemType),
- filesystems: make(map[*Filesystem]struct{}),
- }
+// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
+func (vfs *VirtualFilesystem) Init() error {
+ vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
+ vfs.devices = make(map[devTuple]*registeredDevice)
+ vfs.anonBlockDevMinorNext = 1
+ vfs.anonBlockDevMinor = make(map[uint32]struct{})
+ vfs.fsTypes = make(map[string]*registeredFilesystemType)
+ vfs.filesystems = make(map[*Filesystem]struct{})
vfs.mounts.Init()
// Construct vfs.anonMount.
anonfsDevMinor, err := vfs.GetAnonBlockDevMinor()
if err != nil {
- panic(fmt.Sprintf("VirtualFilesystem.GetAnonBlockDevMinor() failed during VirtualFilesystem construction: %v", err))
+ // This shouldn't be possible since anonBlockDevMinorNext was
+ // initialized to 1 above (no device numbers have been allocated yet).
+ panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err))
}
anonfs := anonFilesystem{
devMinor: anonfsDevMinor,
}
anonfs.vfsfs.Init(vfs, &anonfs)
- vfs.anonMount = &Mount{
- vfs: vfs,
- fs: &anonfs.vfsfs,
- refs: 1,
+ defer anonfs.vfsfs.DecRef()
+ anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
+ if err != nil {
+ // We should not be passing any MountOptions that would cause
+ // construction of this mount to fail.
+ panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err))
}
+ vfs.anonMount = anonMount
- return vfs
+ return nil
}
// PathOperation specifies the path operated on by a VFS method.
@@ -379,6 +384,22 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
if err == nil {
vfs.putResolvingPath(rp)
+
+ // TODO(gvisor.dev/issue/1193): Move inside fsimpl to avoid another call
+ // to FileDescription.Stat().
+ if opts.FileExec {
+ // Only a regular file can be executed.
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
+ if err != nil {
+ fd.DecRef()
+ return nil, err
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
+ return nil, syserror.EACCES
+ }
+ }
+
return fd, nil
}
if !rp.handleError(err) {
@@ -724,6 +745,8 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
// VirtualDentry methods require that a reference is held on the VirtualDentry.
//
// VirtualDentry is analogous to Linux's struct path.
+//
+// +stateify savable
type VirtualDentry struct {
mount *Mount
dentry *Dentry
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
index 3af447fb9..f59061f37 100644
--- a/pkg/sleep/commit_noasm.go
+++ b/pkg/sleep/commit_noasm.go
@@ -28,15 +28,6 @@ import "sync/atomic"
// It is written in assembly because it is called from g0, so it doesn't have
// a race context.
func commitSleep(g uintptr, waitingG *uintptr) bool {
- for {
- // Check if the wait was aborted.
- if atomic.LoadUintptr(waitingG) == 0 {
- return false
- }
-
- // Try to store the G so that wakers know who to wake.
- if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
- return true
- }
- }
+ // Try to store the G so that wakers know who to wake.
+ return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index acbf0229b..65bfcf778 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -299,20 +299,17 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
}
}
- for {
- // Nothing to do if there isn't a G waiting.
- g := atomic.LoadUintptr(&s.waitingG)
- if g == 0 {
- return
- }
+ // Nothing to do if there isn't a G waiting.
+ if atomic.LoadUintptr(&s.waitingG) == 0 {
+ return
+ }
- // Signal to the sleeper that a waker has been asserted.
- if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
- if g != preparingG {
- // We managed to get a G. Wake it up.
- goready(g, 0)
- }
- }
+ // Signal to the sleeper that a waker has been asserted.
+ switch g := atomic.SwapUintptr(&s.waitingG, 0); g {
+ case 0, preparingG:
+ default:
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
}
}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that
+ // receivers (unsub[i], unsub[numReceivers]) are still subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events from src and also clears them from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementations of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The event package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterrputible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to isReady() return
+ // true, then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..985b56ae5
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ g+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..20d7ac23b
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD g+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
new file mode 100644
index 000000000..0995e9053
--- /dev/null
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+package syncevent
+
+import (
+ "unsafe"
+)
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..1c4b0e39a
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an
+// event, return that event using a blocking check, and then unset the event as
+// pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(1)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..112e0e604
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAck() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 2269f6237..4b5a0fca6 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -29,6 +29,7 @@ var (
EACCES = error(syscall.EACCES)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 711969b9b..6e0db2741 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
-// A Listener is a wrapper around a tcpip endpoint that implements
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
// net.Listener.
-type Listener struct {
+type TCPListener struct {
stack *stack.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
@@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr
}
}
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
+ return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
+func (l *TCPListener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
+func (l *TCPListener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
+func (l *TCPListener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
@@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error {
return nil
}
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
// interface.
-type Conn struct {
+type TCPConn struct {
deadlineTimer
wq *waiter.Queue
@@ -228,9 +233,9 @@ type Conn struct {
read buffer.View
}
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
wq: wq,
ep: ep,
}
@@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
}
// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
+func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
@@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}
}
- return NewConn(wq, n), nil
+ return NewTCPConn(wq, n), nil
}
type opErrorer interface {
@@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a
}
// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
+func (c *TCPConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
@@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) {
}
// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *TCPConn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
@@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
+func (c *TCPConn) Close() error {
c.ep.Close()
return nil
}
@@ -440,7 +445,7 @@ func (c *Conn) Close() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
+func (c *TCPConn) CloseRead() error {
if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
+func (c *TCPConn) CloseWrite() error {
if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error {
}
// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
+func (c *TCPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
@@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr {
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
+func (c *TCPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr {
return fullToTCPAddr(a)
}
-func (c *Conn) newOpError(op string, err error) *net.OpError {
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
@@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCP Conn connected to the specified address
+// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
}
- return NewConn(&wq, ep), nil
+ return NewTCPConn(&wq, ep), nil
}
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
deadlineTimer
stack *stack.Stack
@@ -556,9 +561,9 @@ type PacketConn struct {
wq *waiter.Queue
}
-// NewPacketConn creates a new PacketConn.
-func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn {
- c := &PacketConn{
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
stack: s,
ep: ep,
wq: wq,
@@ -567,12 +572,12 @@ func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketC
return c
}
-// DialUDP creates a new PacketConn.
+// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
@@ -591,7 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := NewPacketConn(s, &wq, ep)
+ c := NewUDPConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -608,11 +613,11 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
return c, nil
}
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
@@ -623,7 +628,7 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
+func (c *UDPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -632,13 +637,13 @@ func (c *PacketConn) RemoteAddr() net.Addr {
}
// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
+func (c *UDPConn) Read(b []byte) (int, error) {
bytesRead, _, err := c.ReadFrom(b)
return bytesRead, err
}
// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
@@ -650,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return copy(b, read), fullToUDPAddr(addr), nil
}
-func (c *PacketConn) Write(b []byte) (int, error) {
+func (c *UDPConn) Write(b []byte) (int, error) {
return c.WriteTo(b, nil)
}
// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
@@ -713,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
+func (c *UDPConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
+func (c *UDPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ee077ae83..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -41,7 +41,7 @@ const (
)
func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
+ nc := NewTCPConn(nil, nil)
dlfs := []struct {
name string
f func(time.Time) error
@@ -127,12 +127,16 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -168,13 +172,17 @@ func TestCloseReader(t *testing.T) {
sender.close()
}
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
// using tcp.Forwarder.
func TestCloseReaderWithForwarder(t *testing.T) {
s, err := newLoopbackStack()
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -192,7 +200,7 @@ func TestCloseReaderWithForwarder(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -257,7 +256,7 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseRead(); err != nil {
t.Errorf("c.CloseRead() = %v", err)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -291,7 +294,7 @@ func TestCloseWrite(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
n, e := c.Read(make([]byte, 256))
if n != 0 || e != io.EOF {
@@ -309,7 +312,7 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseWrite(); err != nil {
t.Errorf("c.CloseWrite() = %v", err)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -353,7 +360,7 @@ func TestUDPForwarder(t *testing.T) {
}
defer ep.Close()
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -391,12 +398,16 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -541,7 +560,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
addr := tcpip.FullAddress{NICID, ip, 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..17e94c562 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 4d6ae0871..c6c160dfc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
+ }
+ if got := cm.TClass; got != want {
+ t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index f5d2c127f..b1e92d2d7 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -134,3 +134,44 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
// addr is a valid unicast ethernet address.
return true
}
+
+// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
+// for a multicast IPv4 address.
+//
+// addr MUST be a multicast IPv4 address.
+func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
+ var linkAddrBytes [EthernetAddressSize]byte
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ linkAddrBytes[0] = 0x1
+ linkAddrBytes[2] = 0x5e
+ linkAddrBytes[3] = addr[1] & 0x7F
+ copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:])
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
+
+// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
+// for a multicast IPv6 address.
+//
+// addr MUST be a multicast IPv6 address.
+func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:])
+ linkAddrBytes[0] = 0x33
+ linkAddrBytes[1] = 0x33
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 6634c90f5..7a0014ad9 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -66,3 +66,37 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
})
}
}
+
+func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 Multicast without 24th bit set",
+ addr: "\xe0\x7e\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ {
+ name: "IPv4 Multicast with 24th bit set",
+ addr: "\xe0\xfe\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr)
+ }
+ })
+ }
+}
+
+func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
+ addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
+ t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 70e6ce095..76e88e9b3 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -115,6 +115,19 @@ const (
// for the secret key used to generate an opaque interface identifier as
// outlined by RFC 7217.
OpaqueIIDSecretKeyMinBytes = 16
+
+ // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
+ // is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeByteIdx = 1
+
+ // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
+ // within the byte holding the field, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeMask = 0xF
+
+ // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
+ // a multicast IPv6 address that indicates the address has link-local scope,
+ // as per RFC 4291 section 2.7.
+ ipv6LinkLocalMulticastScope = 2
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -340,6 +353,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
+// link-local multicast address.
+func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+}
+
// IsV6UniqueLocalAddress determines if the provided address is an IPv6
// unique-local address (within the prefix FC00::/7).
func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
@@ -411,6 +430,9 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
}
switch {
+ case IsV6LinkLocalMulticastAddress(addr):
+ return LinkLocalScope, nil
+
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 29f54bc57..426a873b1 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -17,6 +17,7 @@ package header_test
import (
"bytes"
"crypto/sha256"
+ "fmt"
"testing"
"github.com/google/go-cmp/cmp"
@@ -26,11 +27,12 @@ import (
)
const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
@@ -255,6 +257,85 @@ func TestIsV6UniqueLocalAddress(t *testing.T) {
}
}
+func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: true,
+ },
+ {
+ name: "Valid Link Local Multicast with flags",
+ addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ expected: true,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Multicast",
+ addr: "\xe0\x00\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: true,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: false,
+ },
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Link Local",
+ addr: "\xa9\xfe\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
func TestScopeForIPv6Address(t *testing.T) {
tests := []struct {
name string
@@ -269,12 +350,18 @@ func TestScopeForIPv6Address(t *testing.T) {
err: nil,
},
{
- name: "Link Local",
+ name: "Link Local Unicast",
addr: linkLocalAddr,
scope: header.LinkLocalScope,
err: nil,
},
{
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
name: "Global",
addr: globalAddr,
scope: header.GlobalScope,
@@ -300,3 +387,31 @@ func TestScopeForIPv6Address(t *testing.T) {
})
}
}
+
+func TestSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want tcpip.Address
+ }{
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.SolicitedNodeAddr(test.addr); got != test.want {
+ t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 4bfb3149e..dbaccbb36 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -52,10 +52,10 @@ func DefaultTables() IPTables {
Tables: map[string]Table{
TablenameNat: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -74,8 +74,8 @@ func DefaultTables() IPTables {
},
TablenameMangle: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -90,9 +90,9 @@ func DefaultTables() IPTables {
},
TablenameFilter: Table{
Rules: []Rule{
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
- Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
Rule{Target: ErrorTarget{}},
},
BuiltinChains: map[Hook]int{
@@ -135,25 +135,53 @@ func EmptyFilterTable() Table {
}
}
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
// Check runs pkt through the rules for hook. It returns true when the packet
// should continue traversing the network stack and false when it should be
// dropped.
+//
+// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because
- // we're missing features. Jumps, the call stack, etc. aren't checked
- // for yet because we're yet to support them.
-
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
- switch verdict := it.checkTable(hook, pkt, tablename); verdict {
+ table := it.Tables[tablename]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
// If the table returns Accept, move on to the next table.
- case Accept:
+ case chainAccept:
continue
// The Drop verdict is final.
- case Drop:
+ case chainDrop:
return false
- case Stolen, Queue, Repeat, None, Jump, Return, Continue:
- panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
default:
panic(fmt.Sprintf("Unknown verdict %v.", verdict))
}
@@ -163,36 +191,60 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
return true
}
-func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict {
+// Precondition: pkt.NetworkHeader is set.
+func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
- table := it.Tables[tablename]
- for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ {
- switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict {
- // In either of these cases, this table is done with the packet.
- case Accept, Drop:
- return verdict
- // Continue traversing the rules of the table.
- case Continue:
- continue
- case Stolen, Queue, Repeat, None, Jump, Return:
- panic(fmt.Sprintf("Unimplemented verdict %v.", verdict))
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ case RuleAccept:
+ return chainAccept
+
+ case RuleDrop:
+ return chainDrop
+
+ case RuleReturn:
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
default:
- panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
+
}
- panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename))
+ // We got through the entire table without a decision. Default to DROP
+ // for safety.
+ return chainDrop
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict {
+func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// First check whether the packet matches the IP header filter.
// TODO(gvisor.dev/issue/170): Support other fields of the filter.
if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
- return Continue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
// Go through each rule matcher. If they all match, run
@@ -200,14 +252,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
for _, matcher := range rule.Matchers {
matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
- return Drop
+ return RuleDrop, 0
}
if !matches {
- return Continue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
}
// All the matchers matched, so run the target.
- verdict, _ := rule.Target.Action(pkt)
- return verdict
+ return rule.Target.Action(pkt)
}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index 4dd281371..81a2e39a2 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file contains various Targets.
-
package iptables
import (
@@ -21,20 +19,20 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-// UnconditionalAcceptTarget accepts all packets.
-type UnconditionalAcceptTarget struct{}
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
// Action implements Target.Action.
-func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- return Accept, ""
+func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleAccept, 0
}
-// UnconditionalDropTarget denies all packets.
-type UnconditionalDropTarget struct{}
+// DropTarget drops packets.
+type DropTarget struct{}
// Action implements Target.Action.
-func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- return Drop, ""
+func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleDrop, 0
}
// ErrorTarget logs an error and drops the packet. It represents a target that
@@ -42,7 +40,26 @@ func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, strin
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) {
- log.Warningf("ErrorTarget triggered.")
- return Drop, ""
+func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) {
+ return RuleReturn, 0
}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 50893cc55..7d032fd23 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -56,44 +56,21 @@ const (
NumHooks
)
-// A Verdict is returned by a rule's target to indicate how traversal of rules
-// should (or should not) continue.
-type Verdict int
+// A RuleVerdict is what a rule decides should be done with a packet.
+type RuleVerdict int
const (
- // Invalid indicates an unkonwn or erroneous verdict.
- Invalid Verdict = iota
+ // RuleAccept indicates the packet should continue through netstack.
+ RuleAccept RuleVerdict = iota
- // Accept indicates the packet should continue traversing netstack as
- // normal.
- Accept
+ // RuleDrop indicates the packet should be dropped.
+ RuleDrop
- // Drop inicates the packet should be dropped, stopping traversing
- // netstack.
- Drop
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
- // Stolen indicates the packet was co-opted by the target and should
- // stop traversing netstack.
- Stolen
-
- // Queue indicates the packet should be queued for userspace processing.
- Queue
-
- // Repeat indicates the packet should re-traverse the chains for the
- // current hook.
- Repeat
-
- // None indicates no verdict was reached.
- None
-
- // Jump indicates a jump to another chain.
- Jump
-
- // Continue indicates that traversal should continue at the next rule.
- Continue
-
- // Return indicates that traversal should return to the calling chain.
- Return
+ // RuleReturn indicates the packet should return to the previous chain.
+ RuleReturn
)
// IPTables holds all the tables for a netstack.
@@ -132,7 +109,7 @@ type Table struct {
// ValidHooks returns a bitmap of the builtin hooks for the given table.
func (table *Table) ValidHooks() uint32 {
hooks := uint32(0)
- for hook, _ := range table.BuiltinChains {
+ for hook := range table.BuiltinChains {
hooks |= 1 << hook
}
return hooks
@@ -171,9 +148,14 @@ type IPHeaderFilter struct {
// A Matcher is the interface for matching packets.
type Matcher interface {
+ // Name returns the name of the Matcher.
+ Name() string
+
// Match returns whether the packet matches and whether the packet
// should be "hotdropped", i.e. dropped immediately. This is usually
// used for suspicious packets.
+ //
+ // Precondition: packet.NetworkHeader is set.
Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
}
@@ -181,6 +163,6 @@ type Matcher interface {
type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
- // Jump, it also returns the name of the chain to jump to.
- Action(packet tcpip.PacketBuffer) (Verdict, string)
+ // Jump, it also returns the index of the rule to jump to.
+ Action(packet tcpip.PacketBuffer) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 3974c464e..b8b93e78e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = ["channel.go"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 71b9da797..5944ba190 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -20,6 +20,7 @@ package channel
import (
"context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -30,24 +31,139 @@ type PacketInfo struct {
Pkt tcpip.PacketBuffer
Proto tcpip.NetworkProtocolNumber
GSO *stack.GSO
+ Route stack.Route
+}
+
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // mu protects fields below.
+ mu sync.RWMutex
+ // c is the outbound packet channel. Sending to c should hold mu.
+ c chan PacketInfo
+ numWrite int
+ numRead int
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ select {
+ case p := <-q.c:
+ q.numRead++
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ // We have to receive from channel without holding the lock, since it can
+ // block indefinitely. This will cause a window that numWrite - numRead
+ // produces a larger number, but won't go to negative. numWrite >= numRead
+ // still holds.
+ select {
+ case pkt := <-q.c:
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.numRead++
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+
+ // It's important to make sure nobody can see numWrite until we increment it,
+ // so numWrite >= numRead holds.
+ q.mu.Lock()
+ select {
+ case q.c <- p:
+ wrote = true
+ q.numWrite++
+ default:
+ }
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ q.mu.RLock()
+ defer q.mu.RUnlock()
+ n := q.numWrite - q.numRead
+ if n < 0 {
+ panic("numWrite < numRead")
+ }
+ return n
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
}
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
- dispatcher stack.NetworkDispatcher
- mtu uint32
- linkAddr tcpip.LinkAddress
- GSO bool
+ dispatcher stack.NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+ LinkEPCapabilities stack.LinkEndpointCapabilities
- // c is where outbound packets are queued.
- c chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- c: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
@@ -56,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
// Close closes e. Further packet injections will panic. Reads continue to
// succeed until all packets are read.
func (e *Endpoint) Close() {
- close(e.c)
+ e.q.Close()
}
-// Read does non-blocking read for one packet from the outbound packet queue.
+// Read does non-blocking read one packet from the outbound packet queue.
func (e *Endpoint) Read() (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- default:
- return PacketInfo{}, false
- }
+ return e.q.Read()
}
// ReadContext does blocking read for one packet from the outbound packet queue.
// It can be cancelled by ctx, and in this case, it returns false.
func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- case <-ctx.Done():
- return PacketInfo{}, false
- }
+ return e.q.ReadContext(ctx)
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.c:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
// InjectInbound injects an inbound packet.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
@@ -122,11 +231,7 @@ func (e *Endpoint) MTU() uint32 {
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- caps := stack.LinkEndpointCapabilities(0)
- if e.GSO {
- caps |= stack.CapabilityHardwareGSO
- }
- return caps
+ return e.LinkEPCapabilities
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -146,26 +251,31 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
p := PacketInfo{
Pkt: pkt,
Proto: protocol,
GSO: gso,
+ Route: route,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
-packetLoop:
for _, pkt := range pkts {
off := pkt.DataOffset
size := pkt.DataSize
@@ -176,14 +286,13 @@ packetLoop:
},
Proto: protocol,
GSO: gso,
+ Route: route,
}
- select {
- case e.c <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
@@ -197,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
GSO: nil,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e5096ea38..e0db6cf54 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -4,6 +4,22 @@ package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6ff47a742
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 1ceaebfbd..e9fcc89a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -172,42 +172,33 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return broadcastMAC, true
}
if header.IsV4MulticastAddress(addr) {
- // RFC 1112 Host Extensions for IP Multicasting
- //
- // 6.4. Extensions to an Ethernet Local Network Module:
- //
- // An IP host group address is mapped to an Ethernet multicast
- // address by placing the low-order 23-bits of the IP address
- // into the low-order 23 bits of the Ethernet multicast address
- // 01-00-5E-00-00-00 (hex).
- return tcpip.LinkAddress([]byte{
- 0x01,
- 0x00,
- 0x5e,
- addr[header.IPv4AddressSize-3] & 0x7f,
- addr[header.IPv4AddressSize-2],
- addr[header.IPv4AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv4Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 85512f9b2..4f1742938 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -353,6 +353,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
}
pkt.NetworkHeader = headerView[:h.HeaderLength()]
+ hlen := int(h.HeaderLength())
+ tlen := int(h.TotalLength())
+ pkt.Data.TrimFront(hlen)
+ pkt.Data.CapLength(tlen - hlen)
+
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
ipt := e.stack.IPTables()
@@ -361,11 +366,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
return
}
- hlen := int(h.HeaderLength())
- tlen := int(h.TotalLength())
- pkt.Data.TrimFront(hlen)
- pkt.Data.CapLength(tlen - hlen)
-
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
if more || h.FragmentOffset() != 0 {
if pkt.Data.Size() == 0 {
@@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index dc20c0fd7..45dc757c7 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "log"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -194,7 +196,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// TODO(b/148429853): Properly process the NS message and do Neighbor
// Unreachability Detection.
for {
- opt, done, _ := it.Next()
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
if done {
break
}
@@ -253,21 +259,25 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
}
na := header.NDPNeighborAdvert(h.NDPPayload())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
targetAddr := na.TargetAddress()
stack := r.Stack()
rxNICID := r.NICID()
- isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
+ if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil {
// We will only get an error if rxNICID is unrecognized,
// which should not happen. For now short-circuit this
// packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
+ } else if isTentative {
// We just got an NA from a node that owns an address we
// are performing DAD on, implying the address is not
// unique. In this case we let the stack know so it can
@@ -283,13 +293,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P
// scenario is beyond the scope of RFC 4862. As such, we simply
// ignore such a scenario for now and proceed as normal.
//
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
+ //
// TODO(b/143147598): Handle the scenario described above. Also
// inform the netstack integration that a duplicate address was
// detected outside of DAD.
+ //
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ log.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
- e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress)
- if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress)
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress())
+ }
}
case header.ICMPv6EchoRequest:
@@ -408,10 +434,14 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
+
+ // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
+ // route here. Note, we would need the nicID to do this properly so the right
+ // NIC (associated to linkEP) is used to send the NDP NS message.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
@@ -441,23 +471,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if header.IsV6MulticastAddress(addr) {
- // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
- //
- // 7. Address Mapping -- Multicast
- //
- // An IPv6 packet with a multicast destination address DST,
- // consisting of the sixteen octets DST[1] through DST[16], is
- // transmitted to the Ethernet multicast address whose first
- // two octets are the value 3333 hexadecimal and whose last
- // four octets are the last four octets of DST.
- return tcpip.LinkAddress([]byte{
- 0x33,
- 0x33,
- addr[header.IPv6AddressSize-4],
- addr[header.IPv6AddressSize-3],
- addr[header.IPv6AddressSize-2],
- addr[header.IPv6AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 7a6820643..50c4b6474 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -121,21 +121,60 @@ func TestICMPCounts(t *testing.T) {
}
defer r.Release()
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
- typ header.ICMPv6Type
- size int
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
}{
- {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize},
- {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize},
- {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize},
- {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize},
- {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize},
- {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
- {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
- {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize},
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
}
handleIPv6Payload := func(hdr buffer.Prependable) {
@@ -154,10 +193,13 @@ func TestICMPCounts(t *testing.T) {
}
for _, typ := range types {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
handleIPv6Payload(hdr)
}
@@ -270,8 +312,9 @@ func (c *testContext) cleanup() {
}
type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
+ src, dst *channel.Endpoint
+ typ header.ICMPv6Type
+ remoteLinkAddr tcpip.LinkAddress
}
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
@@ -292,6 +335,11 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.
t.Errorf("unexpected protocol number %d", pi.Proto)
return
}
+
+ if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ }
+
ipv6 := header.IPv6(pi.Pkt.Header.View())
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
@@ -339,7 +387,7 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
}
for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit},
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
} {
routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
@@ -366,97 +414,104 @@ func TestLinkResolution(t *testing.T) {
}
func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.DstUnreachable
},
},
{
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.PacketTooBig
},
},
{
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.TimeExceeded
},
},
{
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.ParamProblem
},
},
{
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoRequest
},
},
{
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoReply
},
},
{
- "RouterSolicit",
- header.ICMPv6RouterSolicit,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
},
{
- "RouterAdvert",
- header.ICMPv6RouterAdvert,
- header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
},
},
{
- "NeighborSolicit",
- header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborSolicitMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborSolicit
},
},
{
- "NeighborAdvert",
- header.ICMPv6NeighborAdvert,
- header.ICMPv6NeighborAdvertSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborAdvert
},
},
{
- "RedirectMsg",
- header.ICMPv6RedirectMsg,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RedirectMsg
},
},
@@ -488,16 +543,19 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
)
}
- handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ handleIPv6Payload := func(checksum bool) {
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size),
+ PayloadLength: uint16(typ.size + extraDataLen),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
@@ -522,7 +580,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
// Without setting checksum, the incoming packet should
// be invalid.
- handleIPv6Payload(typ.typ, typ.size, false)
+ handleIPv6Payload(false)
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
@@ -532,7 +590,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
// When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, true)
+ handleIPv6Payload(true)
if got := typStat.Value(); got != 1 {
t.Fatalf("got %s = %d, want = 1", typ.name, got)
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 180a480fd..9aef5234b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index bd732f93f..c9395de52 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -70,76 +70,29 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
return s, ep
}
-// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an
-// NDP NS message with the Source Link Layer Address option results in a
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
- })
- e := channel.New(0, 1280, linkAddr0)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
- if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
- }
-
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
- ns.SetTargetAddress(lladdr0)
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- })
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
- ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
- HopLimit: 255,
- SrcAddr: lladdr1,
- DstAddr: lladdr0,
- })
- e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
- Data: hdr.View().ToVectorisedView(),
- })
-
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
- if linkAddr != linkAddr1 {
- t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1)
- }
-}
-
-// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving
-// an NDP NS message with an invalid Source Link Layer Address option does not
-// result in a new entry in the link address cache for the sender of the
-// message.
-func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
- const nicID = 1
-
tests := []struct {
- name string
- optsBuf []byte
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
}{
{
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
name: "Too Small",
- optsBuf: []byte{1, 1, 1, 2, 3, 4, 5},
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
},
{
name: "Invalid Length",
- optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6},
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
},
}
@@ -186,20 +139,138 @@ func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
})
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
}
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+
+ e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{
+ Data: hdr.View().ToVectorisedView(),
+ })
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
}
- if linkAddr != "" {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr)
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
}
})
}
@@ -238,27 +309,59 @@ func TestHopLimitValidation(t *testing.T) {
})
}
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
- {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- }},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- }},
- {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- }},
- {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- }},
- {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- }},
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
}
for _, typ := range types {
@@ -270,10 +373,13 @@ func TestHopLimitValidation(t *testing.T) {
invalid := stats.Invalid
typStat := typ.statCounter(stats)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ extraDataLen := len(typ.extraData)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen)
+ extraData := buffer.View(hdr.Prepend(extraDataLen))
+ copy(extraData, typ.extraData)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView()))
// Invalid count should initially be 0.
if got := invalid.Value(); got != 0 {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index f5b750046..705cf01ee 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -78,11 +78,15 @@ go_test(
go_test(
name = "stack_test",
size = "small",
- srcs = ["linkaddrcache_test.go"],
+ srcs = [
+ "linkaddrcache_test.go",
+ "nic_test.go",
+ ],
library = ":stack",
deps = [
"//pkg/sleep",
"//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
],
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 245694118..f651871ce 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -167,8 +167,8 @@ type NDPDispatcher interface {
// reason, such as the address being removed). If an error occured
// during DAD, err will be set and resolved must be ignored.
//
- // This function is permitted to block indefinitely without interfering
- // with the stack's operation.
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
// OnDefaultRouterDiscovered will be called when a new default router is
@@ -448,6 +448,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
remaining := ndp.configs.DupAddrDetectTransmits
if remaining == 0 {
ref.setKind(permanent)
+
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
+ }
+
return nil
}
@@ -538,29 +545,19 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
- linkAddr := ndp.nic.linkEP.LinkAddress()
- isValidLinkAddr := header.IsValidUnicastEthernetAddress(linkAddr)
- ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize
- if isValidLinkAddr {
- // Only include a Source Link Layer Address option if the NIC has a valid
- // link layer address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
- // LinkEndpoint.LinkAddress) before reaching this point.
- ndpNSSize += header.NDPLinkLayerAddressSize
+ // Route should resolve immediately since snmc is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)
+ } else if c != nil {
+ log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + ndpNSSize)
- pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
ns := header.NDPNeighborSolicit(pkt.NDPPayload())
ns.SetTargetAddress(addr)
-
- if isValidLinkAddr {
- ns.Options().Serialize(header.NDPOptionsSerializer{
- header.NDPSourceLinkLayerAddressOption(linkAddr),
- })
- }
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
sent := r.Stats().ICMP.V6PacketsSent
@@ -607,8 +604,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndp.nic.stack.ndpDisp != nil {
- go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
}
}
@@ -916,22 +913,21 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
return
}
- // We do not already have an address within the prefix, prefix. Do the
+ // We do not already have an address with the prefix prefix. Do the
// work as outlined by RFC 4862 section 5.5.3.d if n is configured
- // to auto-generated global addresses by SLAAC.
- ndp.newAutoGenAddress(prefix, pl, vl)
+ // to auto-generate global addresses by SLAAC.
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ ndp.doSLAAC(prefix, pl, vl)
}
-// newAutoGenAddress generates a new SLAAC address with the provided lifetimes
+// doSLAAC generates a new SLAAC address with the provided lifetimes
// for prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
-func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) {
- // Are we configured to auto-generate new global addresses?
- if !ndp.configs.AutoGenGlobalAddresses {
- return
- }
-
+func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
// section 5.5.3.d.
@@ -1152,22 +1148,36 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// cleanupHostOnlyState cleans up any state that is only useful for hosts.
+// cleanupState cleans up ndp's state.
//
-// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
-// host to a router. This function will invalidate all discovered on-link
-// prefixes, discovered routers, and auto-generated addresses as routers do not
-// normally process Router Advertisements to discover default routers and
-// on-link prefixes, and auto-generate addresses via SLAAC.
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupHostOnlyState() {
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
+ linkLocalAddrs := 0
for addr := range ndp.autoGenAddresses {
+ // RFC 4862 section 5 states that routers are also expected to generate a
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && linkLocalSubnet.Contains(addr) {
+ linkLocalAddrs++
+ continue
+ }
+
ndp.invalidateAutoGenAddress(addr)
}
- if got := len(ndp.autoGenAddresses); got != 0 {
- log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got)
+ if got := len(ndp.autoGenAddresses); got != linkLocalAddrs {
+ log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs)
}
for prefix := range ndp.onLinkPrefixes {
@@ -1175,7 +1185,7 @@ func (ndp *ndpState) cleanupHostOnlyState() {
}
if got := len(ndp.onLinkPrefixes); got != 0 {
- log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got)
+ log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)
}
for router := range ndp.defaultRouters {
@@ -1183,7 +1193,7 @@ func (ndp *ndpState) cleanupHostOnlyState() {
}
if got := len(ndp.defaultRouters); got != 0 {
- log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got)
+ log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got)
}
}
@@ -1215,8 +1225,17 @@ func (ndp *ndpState) startSolicitingRouters() {
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
+ // Route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)
+ } else if c != nil {
+ log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())
+ }
+
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 726468e41..6e9306d09 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -42,6 +42,7 @@ const (
linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
defaultTimeout = 100 * time.Millisecond
defaultAsyncEventTimeout = time.Second
)
@@ -50,6 +51,7 @@ var (
llAddr1 = header.LinkLocalAddr(linkAddr1)
llAddr2 = header.LinkLocalAddr(linkAddr2)
llAddr3 = header.LinkLocalAddr(linkAddr3)
+ llAddr4 = header.LinkLocalAddr(linkAddr4)
dstAddr = tcpip.FullAddress{
Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
Port: 25,
@@ -84,39 +86,6 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi
return prefix, subnet, addrForSubnet(subnet, linkAddr)
}
-// TestDADDisabled tests that an address successfully resolves immediately
-// when DAD is not enabled (the default for an empty stack.Options).
-func TestDADDisabled(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- }
-
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
- }
-
- // Should get the address immediately since we should not have performed
- // DAD on it.
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
- }
-
- // We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
- t.Fatalf("got NeighborSolicit = %d, want = 0", got)
- }
-}
-
// ndpDADEvent is a set of parameters that was passed to
// ndpDispatcher.OnDuplicateAddressDetectionStatus.
type ndpDADEvent struct {
@@ -298,25 +267,113 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
}
}
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// resolved flag set to resolved with the specified err.
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
+ return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
+}
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ const nicID = 1
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
const nicID = 1
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
@@ -335,9 +392,13 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
+ if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -378,17 +439,8 @@ func TestDADResolve(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != nicID {
- t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
@@ -413,15 +465,29 @@ func TestDADResolve(t *testing.T) {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- // Check NDP packet.
+ // Make sure the right remote link address is used.
+ snmc := header.SolicitedNodeAddr(addr1)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ // Check NDP NS packet.
+ //
+ // As per RFC 4861 section 4.3, a possible option is the Source Link
+ // Layer option, but this option MUST NOT be included when the source
+ // address of the packet is the unspecified address.
checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
checker.NDPNSTargetAddress(addr1),
- checker.NDPNSOptions([]header.NDPOption{
- header.NDPSourceLinkLayerAddressOption(linkAddr1),
- }),
+ checker.NDPNSOptions(nil),
))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
})
}
@@ -432,6 +498,8 @@ func TestDADResolve(t *testing.T) {
// a node doing DAD for the same address), or if another node is detected to own
// the address already (receive an NA message for the tentative address).
func TestDADFail(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
makeBuf func(tgt tcpip.Address) buffer.Prependable
@@ -467,13 +535,17 @@ func TestDADFail(t *testing.T) {
{
"RxAdvert",
func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(pkt.NDPPayload())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -497,7 +569,7 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
@@ -509,22 +581,22 @@ func TestDADFail(t *testing.T) {
e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Receive a packet to simulate multiple nodes owning or
@@ -548,102 +620,109 @@ func TestDADFail(t *testing.T) {
// something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if e.resolved {
- t.Fatal("got DAD event w/ resolved = true, want = false")
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
})
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ const nicID = 1
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(1, addr1); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if e.resolved {
- t.Fatal("got DAD event w/ resolved = true, want = false")
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- }
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
- }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -664,6 +743,10 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
func TestSetNDPConfigurations(t *testing.T) {
+ const nicID1 = 1
+ const nicID2 = 2
+ const nicID3 = 3
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -687,7 +770,7 @@ func TestSetNDPConfigurations(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
@@ -695,17 +778,28 @@ func TestSetNDPConfigurations(t *testing.T) {
NDPDisp: &ndpDisp,
})
+ expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("expected DAD event for %s", addr)
+ }
+ }
+
// This NIC(1)'s NDP configurations will be updated to
// be different from the default.
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
}
// Created before updating NIC(1)'s NDP configurations
// but updating NIC(1)'s NDP configurations should not
// affect other existing NICs.
- if err := s.CreateNIC(2, e); err != nil {
- t.Fatalf("CreateNIC(2) = %s", err)
+ if err := s.CreateNIC(nicID2, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
}
// Update the NDP configurations on NIC(1) to use DAD.
@@ -713,36 +807,38 @@ func TestSetNDPConfigurations(t *testing.T) {
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
RetransmitTimer: test.retransmitTimer,
}
- if err := s.SetNDPConfigurations(1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
}
// Created after updating NIC(1)'s NDP configurations
// but the stack's default NDP configurations should not
// have been updated.
- if err := s.CreateNIC(3, e); err != nil {
- t.Fatalf("CreateNIC(3) = %s", err)
+ if err := s.CreateNIC(nicID3, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err)
}
// Add addresses for each NIC.
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
}
- if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
}
- if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ expectDADEvent(nicID2, addr2)
+ if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
}
+ expectDADEvent(nicID3, addr3)
// Address should not be considered bound to NIC(1) yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
}
// Should get the address on NIC(2) and NIC(3)
@@ -750,31 +846,31 @@ func TestSetNDPConfigurations(t *testing.T) {
// it as the stack was configured to not do DAD by
// default and we only updated the NDP configurations on
// NIC(1).
- addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
}
- addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
}
// Sleep until right (500ms before) before resolution to
// make sure the address didn't resolve on NIC(1) yet.
const delta = 500 * time.Millisecond
time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
}
// Wait for DAD to resolve.
@@ -788,25 +884,16 @@ func TestSetNDPConfigurations(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+ if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
}
})
}
@@ -1524,7 +1611,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
}
// Checks to see if list contains an IPv6 address, item.
-func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
protocolAddress := tcpip.ProtocolAddress{
Protocol: header.IPv6ProtocolNumber,
AddressWithPrefix: item,
@@ -1650,7 +1737,7 @@ func TestAutoGenAddr(t *testing.T) {
// with non-zero lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -1666,10 +1753,10 @@ func TestAutoGenAddr(t *testing.T) {
// Receive an RA with prefix2 in a PI.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
@@ -1690,10 +1777,10 @@ func TestAutoGenAddr(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
t.Fatalf("Should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
t.Fatalf("Should have %s in the list of addresses", addr2)
}
}
@@ -1838,7 +1925,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
expectPrimaryAddr(addr1)
@@ -1846,7 +1933,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix1 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// addr should still be the primary endpoint as there are no other addresses.
@@ -1864,7 +1951,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1872,7 +1959,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
// Deprecate addr for prefix2 immedaitely.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr1 should be the primary endpoint now since addr2 is deprecated but
@@ -1967,7 +2054,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive PI for prefix2.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -1975,10 +2062,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Receive a PI for prefix1.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr1)
@@ -1994,10 +2081,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since addr1 is deprecated but
@@ -2034,10 +2121,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be deprecated.
expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
// addr2 should be the primary endpoint now since it is not deprecated.
@@ -2048,10 +2135,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
// Wait for addr of prefix1 to be invalidated.
expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout)
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
expectPrimaryAddr(addr2)
@@ -2097,10 +2184,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) {
case <-time.After(newMinVLDuration + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
@@ -2585,7 +2672,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2598,7 +2685,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
default:
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
@@ -2609,7 +2696,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) {
t.Fatal("unexpectedly received an auto gen addr event")
case <-time.After(lifetimeSeconds*time.Second + defaultTimeout):
}
- if !contains(s.NICInfo()[1].ProtocolAddresses, addr) {
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
t.Fatalf("Should have %s in the list of addresses", addr1)
}
}
@@ -2687,17 +2774,17 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
const validLifetimeSecondPrefix1 = 1
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
expectAutoGenAddrEvent(addr1, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
// Receive an RA with prefix2 in a PI with a large valid lifetime.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
expectAutoGenAddrEvent(addr2, newAddr)
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
@@ -2710,10 +2797,10 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
t.Fatalf("should not have %s in the list of addresses", addr1)
}
- if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
t.Fatalf("should have %s in the list of addresses", addr2)
}
}
@@ -2866,257 +2953,333 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
-// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
-// and prefixes, and auto-generated addresses get invalidated when a NIC
-// becomes a router.
-func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
t.Parallel()
const (
- lifetimeSeconds = 5
- maxEvents = 4
- nicID1 = 1
- nicID2 = 2
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
)
prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
e2Addr1 := addrForSubnet(subnet1, linkAddr2)
e2Addr2 := addrForSubnet(subnet2, linkAddr2)
-
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
+ llAddrWithPrefix1 := tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 64,
}
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
+ llAddrWithPrefix2 := tcpip.AddressWithPrefix{
+ Address: llAddr2,
+ PrefixLen: 64,
+ }
+
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Forwarding Enable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
},
- NDPDisp: &ndpDisp,
- })
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "NIC Disable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
}
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
- return false, ndpRouterEvent{}
- }
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
+ return false, ndpRouterEvent{}
+ }
- return false, ndpPrefixEvent{}
- }
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
+ return false, ndpPrefixEvent{}
+ }
- return false, ndpAutoGenAddrEvent{}
- }
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and
- // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from
- // llAddr2) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
+ return false, ndpAutoGenAddrEvent{}
+ }
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !contains(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !contains(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !contains(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !contains(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
- s.SetForwarding(true)
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
- // Collect invalidation events after becoming a router
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr1, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr2, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr1, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
+ test.cleanupFn(t, s)
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if contains(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if contains(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if contains(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if contains(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
- // Should not get any more events (invalidation timers should have been
- // cancelled when we transitioned into a router).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
}
}
@@ -3216,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
func TestRouterSolicitation(t *testing.T) {
t.Parallel()
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3234,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with delay",
+ linkHeaderLen: 1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3242,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Single RS without delay",
+ linkHeaderLen: 2,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3250,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3287,7 +3456,11 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
ctx, _ := context.WithTimeout(context.Background(), timeout)
@@ -3300,6 +3473,12 @@ func TestRouterSolicitation(t *testing.T) {
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
+
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
checker.IPv6(t,
p.Pkt.Header.View(),
checker.SrcAddr(header.IPv6Any),
@@ -3307,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS(),
)
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
@@ -3323,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Make sure each RS got sent at the right
@@ -3356,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) {
})
}
-// TestStopStartSolicitingRouters tests that when forwarding is enabled or
-// disabled, router solicitations are stopped or started, respecitively.
func TestStopStartSolicitingRouters(t *testing.T) {
t.Parallel()
+ const nicID = 1
const interval = 500 * time.Millisecond
const delay = time.Second
const maxRtrSolicitations = 3
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Forwarding enabled and disabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Enable forwarding which should stop router solicitations.
- s.SetForwarding(true)
- ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("Should not have sent more than one RS message")
- }
- }
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "NIC disabled and enabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Enabling forwarding again should do nothing.
- s.SetForwarding(true)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
- }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Disable forwarding which should start router solicitations.
- s.SetForwarding(false)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Disabling forwarding again should do nothing.
- s.SetForwarding(false)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7dad9a8cb..46d3a6646 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -16,6 +16,7 @@ package stack
import (
"log"
+ "reflect"
"sort"
"strings"
"sync/atomic"
@@ -26,6 +27,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -39,6 +48,7 @@ type NIC struct {
mu struct {
sync.RWMutex
+ enabled bool
spoofing bool
promiscuous bool
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
@@ -56,6 +66,14 @@ type NIC struct {
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
+
+ DisabledRx DirectionStats
+}
+
+func makeNICStats() NICStats {
+ var s NICStats
+ tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
}
// DirectionStats includes packet and byte counts.
@@ -99,16 +117,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
name: name,
linkEP: ep,
context: ctx,
- stats: NICStats{
- Tx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- Rx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- },
+ stats: makeNICStats(),
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
@@ -131,20 +140,97 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
- n.attachLinkEndpoint()
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if n.mu.enabled {
+ return nil
+ }
+
+ n.mu.enabled = true
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if err := n.AddAddress(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -166,44 +252,38 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
- n.mu.Lock()
- defer n.mu.Unlock()
-
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
return err
}
- // Do not auto-generate an IPv6 link-local address for loopback devices.
- if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
- var addr tcpip.Address
- if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
- } else {
- l2addr := n.linkEP.LinkAddress()
-
- // Only attempt to generate the link-local address if we have a valid MAC
- // address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
- // LinkEndpoint.LinkAddress) before reaching this point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
- }
-
- addr = header.LinkLocalAddr(l2addr)
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the NIC was
+ // last enabled, other devices may have acquired the same addresses.
+ for _, r := range n.mu.endpoints {
+ addr := r.ep.ID().LocalAddress
+ if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
+ continue
}
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- },
- }, CanBePrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ r.setKind(permanentTentative)
+ if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
return err
}
}
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
+
// If we are operating as a router, then do not solicit routers since we
// won't process the RAs anyways.
//
@@ -218,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -227,7 +334,7 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupState(true /* hostOnly */)
n.mu.ndp.stopSolicitingRouters()
}
@@ -242,12 +349,6 @@ func (n *NIC) becomeIPv6Host() {
n.mu.ndp.startSolicitingRouters()
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
-}
-
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
@@ -633,7 +734,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process.
+ // mark it as tentative so it goes through the DAD process if the NIC is
+ // enabled. If the NIC is not enabled, DAD will be started when the NIC is
+ // enabled.
if isIPv6Unicast && kind == permanent {
kind = permanentTentative
}
@@ -668,8 +771,8 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
n.insertPrimaryEndpointLocked(ref, peb)
- // If we are adding a tentative IPv6 address, start DAD.
- if isIPv6Unicast && kind == permanentTentative {
+ // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
+ if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
@@ -700,11 +803,10 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- // TODO(b/140898488): Should tentative addresses be
- // returned?
+ case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -1002,6 +1104,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1016,11 +1127,23 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ // If the NIC is not yet enabled, don't receive any packets.
+ if !enabled {
+ n.mu.RUnlock()
+
+ n.stats.DisabledRx.Packets.Increment()
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ return
+ }
+
n.stats.Rx.Packets.Increment()
n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
+ n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
@@ -1032,7 +1155,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// Are any packet sockets listening for this network protocol?
- n.mu.RLock()
packetEPs := n.mu.packetEPs[protocol]
// Check whether there are packet sockets listening for every protocol.
// If we received a packet with protocol EthernetProtocolAll, then the
@@ -1197,11 +1319,21 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Name returns the name of n.
+func (n *NIC) Name() string {
+ return n.name
+}
+
// Stack returns the instance of the Stack that owns this NIC.
func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
@@ -1388,7 +1520,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.getKind() != permanentExpired || r.nic.mu.spoofing
+ return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
new file mode 100644
index 000000000..edaee3b86
--- /dev/null
+++ b/pkg/tcpip/stack/nic_test.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
+ // When the NIC is disabled, the only field that matters is the stats field.
+ // This test is limited to stats counter checks.
+ nic := NIC{
+ stats: makeNICStats(),
+ }
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ec91f60dd..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -277,7 +285,7 @@ type NetworkProtocol interface {
// DefaultPrefixLen returns the protocol's default prefix length.
DefaultPrefixLen() int
- // ParsePorts returns the source and destination addresses stored in a
+ // ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 517f4b941..f565aafb2 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -225,7 +225,9 @@ func (r *Route) Release() {
// Clone Clone a route such that the original one can be released and the new
// one will remain valid.
func (r *Route) Clone() Route {
- r.ref.incRef()
+ if r.ref != nil {
+ r.ref.incRef()
+ }
return *r
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7057b110e..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -795,6 +795,8 @@ func (s *Stack) Forwarding() bool {
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NIC to use for given destination address ranges.
+//
+// This method takes ownership of the table.
func (s *Stack) SetRouteTable(table []tcpip.Route) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -809,6 +811,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route {
return append([]tcpip.Route(nil), s.routeTable...)
}
+// AddRoute appends a route to the route table.
+func (s *Stack) AddRoute(route tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.routeTable = append(s.routeTable, route)
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -872,6 +881,8 @@ type NICOptions struct {
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
// NICOptions. See the documentation on type NICOptions for details on how
// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -881,8 +892,16 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, opts.Name, ep, opts.Context)
+ // Make sure name is unique, unless unnamed.
+ if opts.Name != "" {
+ for _, n := range s.nics {
+ if n.Name() == opts.Name {
+ return tcpip.ErrDuplicateNICID
+ }
+ }
+ }
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -892,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
-// `LinkEndpoint.Attach` to start delivering packets to it.
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
// CheckNIC checks if a NIC is usable.
func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return false
}
- return false
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
+ }
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICAddressRanges returns a map of NICIDs to their associated subnets.
@@ -971,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.isLoopback(),
}
@@ -1133,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
@@ -1143,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
@@ -1373,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1391,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1596,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
s.tablesMu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 834fe9487..e15db40fb 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = true
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatalf("link endpoint not attached to a network disatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRoutesWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Disabling NIC1 should result in no routes to odd addresses. Routes to even
+ // addresses should continue to be available as NIC2 is still enabled.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ nic1Dst := tcpip.Address("\x05")
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ nic2Dst := tcpip.Address("\x06")
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Disabling NIC2 should result in no routes to even addresses. No route
+ // should be available to any address as routes to odd addresses were made
+ // unavailable by disabling NIC1 above.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ // Enabling NIC1 should make routes to odd addresses available again. Routes
+ // to even addresses should continue to be unavailable as NIC2 is still
+ // disabled.
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+}
+
+func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ nic1Dst := tcpip.Address("\x05")
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ nic2Dst := tcpip.Address("\x06")
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC1 should fail.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC2 should fail.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ // Writes with Routes that use the re-enabled NIC1 should succeed.
+ // TODO(b/147015577): Should we instead completely invalidate all Routes that
+ // were bound to a disabled NIC at some point?
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -1792,6 +2106,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
+func TestCreateNICWithOptions(t *testing.T) {
+ type callArgsAndExpect struct {
+ nicID tcpip.NICID
+ opts stack.NICOptions
+ err *tcpip.Error
+ }
+
+ tests := []struct {
+ desc string
+ calls []callArgsAndExpect
+ }{
+ {
+ desc: "DuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth1"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth2"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "DuplicateName",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "lo"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{Name: "lo"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "Unnamed",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ },
+ },
+ {
+ desc: "UnnamedDuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ for _, call := range test.calls {
+ if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
+ t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
+ }
+ }
+ })
+ }
+}
+
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
@@ -1894,112 +2293,6 @@ func TestNICForwarding(t *testing.T) {
}
}
-// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
-// using the modified EUI-64 of the NIC's MAC address (or lack there-of if
-// disabled (default)). Note, DAD will be disabled in these tests.
-func TestNICAutoGenAddr(t *testing.T) {
- tests := []struct {
- name string
- autoGen bool
- linkAddr tcpip.LinkAddress
- iidOpts stack.OpaqueInterfaceIdentifierOptions
- shouldGen bool
- }{
- {
- "Disabled",
- false,
- linkAddr1,
- stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(nicID tcpip.NICID, _ string) string {
- return fmt.Sprintf("nic%d", nicID)
- },
- },
- false,
- },
- {
- "Enabled",
- true,
- linkAddr1,
- stack.OpaqueInterfaceIdentifierOptions{},
- true,
- },
- {
- "Nil MAC",
- true,
- tcpip.LinkAddress([]byte(nil)),
- stack.OpaqueInterfaceIdentifierOptions{},
- false,
- },
- {
- "Empty MAC",
- true,
- tcpip.LinkAddress(""),
- stack.OpaqueInterfaceIdentifierOptions{},
- false,
- },
- {
- "Invalid MAC",
- true,
- tcpip.LinkAddress("\x01\x02\x03"),
- stack.OpaqueInterfaceIdentifierOptions{},
- false,
- },
- {
- "Multicast MAC",
- true,
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- stack.OpaqueInterfaceIdentifierOptions{},
- false,
- },
- {
- "Unspecified MAC",
- true,
- tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
- stack.OpaqueInterfaceIdentifierOptions{},
- false,
- },
- }
-
- for _, test := range tests {
- t.Run(test.name, func(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- OpaqueIIDOpts: test.iidOpts,
- }
-
- if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by default.
- opts.AutoGenIPv6LinkLocal = true
- }
-
- e := channel.New(10, 1280, test.linkAddr)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
-
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
-
- if test.shouldGen {
- // Should have auto-generated an address and resolved immediately (DAD
- // is disabled).
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
- }
- } else {
- // Should not have auto-generated an address.
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
- }
- }
- })
- }
-}
-
// TestNICContextPreservation tests that you can read out via stack.NICInfo the
// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
func TestNICContextPreservation(t *testing.T) {
@@ -2040,11 +2333,9 @@ func TestNICContextPreservation(t *testing.T) {
}
}
-// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local
-// addresses with opaque interface identifiers. Link Local addresses should
-// always be generated with opaque IIDs if configured to use them, even if the
-// NIC has an invalid MAC address.
-func TestNICAutoGenAddrWithOpaque(t *testing.T) {
+// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local
+// addresses.
+func TestNICAutoGenLinkLocalAddr(t *testing.T) {
const nicID = 1
var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
@@ -2056,108 +2347,201 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) {
t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
}
+ nicNameFunc := func(_ tcpip.NICID, name string) string {
+ return name
+ }
+
tests := []struct {
- name string
- nicName string
- autoGen bool
- linkAddr tcpip.LinkAddress
- secretKey []byte
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
+ shouldGen bool
+ expectedAddr tcpip.Address
}{
{
name: "Disabled",
nicName: "nic1",
autoGen: false,
linkAddr: linkAddr1,
- secretKey: secretKey[:],
+ shouldGen: false,
},
{
- name: "Enabled",
- nicName: "nic1",
- autoGen: true,
- linkAddr: linkAddr1,
- secretKey: secretKey[:],
+ name: "Disabled without OIID options",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: false,
},
- // These are all cases where we would not have generated a
- // link-local address if opaque IIDs were disabled.
+
+ // Tests for EUI64 based addresses.
{
- name: "Nil MAC and empty nicName",
- nicName: "",
+ name: "EUI64 Enabled",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddr(linkAddr1),
+ },
+ {
+ name: "EUI64 Empty MAC",
autoGen: true,
- linkAddr: tcpip.LinkAddress([]byte(nil)),
- secretKey: secretKey[:1],
+ shouldGen: false,
},
{
- name: "Empty MAC and empty nicName",
+ name: "EUI64 Invalid MAC",
autoGen: true,
- linkAddr: tcpip.LinkAddress(""),
- secretKey: secretKey[:2],
+ linkAddr: "\x01\x02\x03",
+ shouldGen: false,
},
{
- name: "Invalid MAC",
- nicName: "test",
+ name: "EUI64 Multicast MAC",
autoGen: true,
- linkAddr: tcpip.LinkAddress("\x01\x02\x03"),
- secretKey: secretKey[:3],
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ shouldGen: false,
},
{
- name: "Multicast MAC",
- nicName: "test2",
+ name: "EUI64 Unspecified MAC",
autoGen: true,
- linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- secretKey: secretKey[:4],
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ shouldGen: false,
},
+
+ // Tests for Opaque IID based addresses.
{
- name: "Unspecified MAC and nil SecretKey",
+ name: "OIID Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]),
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
+ {
+ name: "OIID Empty MAC and empty nicName",
+ autoGen: true,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:1],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]),
+ },
+ {
+ name: "OIID Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:2],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]),
+ },
+ {
+ name: "OIID Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:3],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]),
+ },
+ {
+ name: "OIID Unspecified MAC and nil SecretKey",
nicName: "test3",
autoGen: true,
- linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
- NICNameFromID: func(_ tcpip.NICID, nicName string) string {
- return nicName
- },
- SecretKey: test.secretKey,
- },
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
}
-
- if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when
- // test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by
- // default.
- opts.AutoGenIPv6LinkLocal = true
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
}
- e := channel.New(10, 1280, test.linkAddr)
+ e := channel.New(0, 1280, test.linkAddr)
s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName}
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
- if test.autoGen {
- // Should have auto-generated an address and
- // resolved immediately (DAD is disabled).
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ var expectedMainAddr tcpip.AddressWithPrefix
+ if test.shouldGen {
+ expectedMainAddr = tcpip.AddressWithPrefix{
+ Address: test.expectedAddr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ }
+
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
} else {
// Should not have auto-generated an address.
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address")
+ default:
}
}
+
+ gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if gotMainAddr != expectedMainAddr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ }
})
}
}
@@ -2215,6 +2599,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
// link-local addresses will only be assigned after the DAD process resolves.
func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ const nicID = 1
+
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
@@ -2226,20 +2612,20 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
NDPDisp: &ndpDisp,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Address should not be considered bound to the
// NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
@@ -2253,25 +2639,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicID != 1 {
- t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID)
- }
- if e.addr != linkLocalAddr {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
}
@@ -2413,13 +2790,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
)
// Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
@@ -2493,6 +2871,18 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
expectedLocalAddr: linkLocalAddr1,
},
{
+ name: "Link Local most preferred for link local multicast (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
name: "Unique Local most preferred (last address)",
nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
connectAddr: uniqueLocalAddr2,
@@ -2561,3 +2951,214 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
})
}
}
+
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
+ // not been enabled yet.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
+// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
+// was disabled have DAD performed on them when the NIC is enabled.
+func TestDoDADWhenNICEnabled(t *testing.T) {
+ t.Parallel()
+
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(dadTransmits, 1280, linkAddr1)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 128,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ // Address should be in the list of all addresses.
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should be tentative so it should not be a main address.
+ got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Enabling the NIC should start DAD for the address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+
+ // Enabling the NIC again should be a no-op.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 59c9b3fb0..3dc5d87d6 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -323,11 +323,17 @@ type ControlMessages struct {
// TOS is the IPv4 type of service of the associated packet.
TOS uint8
- // HasTClass indicates whether Tclass is valid/set.
+ // HasTClass indicates whether TClass is valid/set.
HasTClass bool
- // Tclass is the IPv6 traffic class of the associated packet.
- TClass int32
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -335,9 +341,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
@@ -496,13 +508,25 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
+ // IPV6_TCLASS ancillary message is passed with incoming packets.
+ ReceiveTClassOption SockOptBool = iota
+
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
- ReceiveTOSOption SockOptBool = iota
+ ReceiveTOSOption
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
V6OnlyOption
+
+ // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
+ // if more inforamtion is provided with incoming packets such
+ // as interface index and address.
+ ReceiveIPPacketInfoOption
+
+ // TODO(b/146901447): convert existing bool socket options to be handled via
+ // Get/SetSockOptBool
)
// SockOptInt represents socket options which values have the int type.
@@ -626,6 +650,12 @@ type TCPLingerTimeoutOption time.Duration
// before being marked closed.
type TCPTimeWaitTimeoutOption time.Duration
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// for a handshake till the specified timeout until a segment with data arrives.
+type TCPDeferAcceptOption time.Duration
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
@@ -679,6 +709,20 @@ type IPv4TOSOption uint8
// for all subsequent outgoing IPv6 packets from the endpoint.
type IPv6TrafficClassOption uint8
+// IPPacketInfo is the message struture for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
+
+ // DestinationAddr is the destination address.
+ DestinationAddr Address
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
@@ -1118,6 +1162,10 @@ type ReadErrors struct {
// InvalidEndpointState is the number of times we found the endpoint state
// to be unexpected.
InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
}
// WriteErrors collects packet write errors from an endpoint write call.
@@ -1160,7 +1208,9 @@ type TransportEndpointStats struct {
// marker interface.
func (*TransportEndpointStats) IsEndpointStats() {}
-func fillIn(v reflect.Value) {
+// InitStatCounters initializes v's fields with nil StatCounter fields to new
+// StatCounters.
+func InitStatCounters(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
if s, ok := v.Addr().Interface().(**StatCounter); ok {
@@ -1168,14 +1218,14 @@ func fillIn(v reflect.Value) {
*s = new(StatCounter)
}
} else {
- fillIn(v)
+ InitStatCounters(v)
}
}
}
// FillIn returns a copy of s with nil fields initialized to new StatCounters.
func (s Stats) FillIn() Stats {
- fillIn(reflect.ValueOf(&s).Elem())
+ InitStatCounters(reflect.ValueOf(&s).Elem())
return s
}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 48764b978..2f98a996f 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -25,6 +25,8 @@ import (
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 42afb3f5b..426da1ee6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9ce500e80..113d92901 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..09a1cd436 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -76,6 +76,7 @@ type endpoint struct {
sndBufSize int
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
}
// NewEndpoint returns a new packet endpoint.
@@ -98,6 +99,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
@@ -120,6 +126,7 @@ func (ep *endpoint) Close() {
}
ep.closed = true
+ ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -211,7 +218,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
// - packet(7).
- return tcpip.ErrNotSupported
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound {
+ return tcpip.ErrAlreadyBound
+ }
+
+ // Unregister endpoint with all the nics.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+
+ return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ee9c4c58b..2ef5fac76 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 4acd9fb9a..272e8f570 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -57,6 +57,7 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -90,6 +91,7 @@ go_test(
tags = ["flaky"],
deps = [
":tcp",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index d469758eb..13e383ffc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
+ n := newEndpoint(l.stack, netProto, queue)
n.v6only = l.v6only
n.ID = s.id
n.boundNICID = s.route.NICID()
@@ -271,18 +271,19 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpoint creates a new endpoint in connected state and then performs
-// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
- ep, err := l.createConnectingEndpoint(s, isn, irs, opts)
+ ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue)
if err != nil {
return nil, err
}
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
@@ -290,15 +291,21 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
// Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
-
- h.resetToSynRcvd(isn, irs, opts)
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -377,16 +384,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
defer e.decSynRcvdCount()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(n)
@@ -546,7 +551,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{})
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -576,8 +581,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// space available in the backlog.
// Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
+ n.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
@@ -610,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4e3c5419c..7730e6445 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -86,6 +86,19 @@ type handshake struct {
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
+
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
+
+ // deferAccept if non-zero will drop the final ACK for a passive
+ // handshake till an ACK segment with data is received or the timeout is
+ // hit.
+ deferAccept time.Duration
+
+ // acked is true if the the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
}
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
@@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
return h
}
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
// FindWndScale determines the window scale to use for the given maximum window
// size.
func FindWndScale(wnd seqnum.Size) int {
@@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
+ h.deferAccept = deferAccept
h.ep.mu.Lock()
h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
@@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is not zero and this is a bare ACK and the
+ // timeout is not hit then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+
h.ep.mu.Lock()
h.ep.transitionToStateEstablishedLocked(h)
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
h.ep.mu.Unlock()
-
return nil
}
@@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.startTime = time.Now()
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
@@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error {
switch index, _ := s.Fetch(true); index {
case wakerForResend:
timeOut *= 2
- if timeOut > 60*time.Second {
+ if timeOut > MaxRTO {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ }
case wakerForNotification:
n := h.ep.fetchNotifications()
@@ -944,6 +989,10 @@ func (e *endpoint) transitionToStateCloseLocked() {
// to any other listening endpoint. We reply with RST if we cannot find one.
func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ // Dual-stack socket, try IPv4.
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ }
if ep == nil {
replyWithReset(s)
s.decRef()
@@ -1323,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted
}
@@ -1606,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 {
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..d792b07d6 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct {
epQ epQueue
newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
id int
+ wg sync.WaitGroup
}
func newProcessor(id int) *processor {
p := &processor{
id: id,
}
+ p.wg.Add(1)
go p.handleSegments()
return p
}
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) wait() {
+ p.wg.Wait()
+}
+
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
@@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() {
const newEndpointWaker = 1
+ const closeWaker = 2
s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done()
for {
- s.Fetch(true)
+ id, ok := s.Fetch(true)
+ if ok && id == closeWaker {
+ p.wg.Done()
+ return
+ }
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() {
continue
@@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
}
}
+func (d *dispatcher) close() {
+ for _, p := range d.processors {
+ p.close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ for _, p := range d.processors {
+ p.wait()
+ }
+}
+
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 13718ff55..f1ad19dac 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,8 @@ const (
notifyDrain
notifyReset
notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -498,6 +500,13 @@ type endpoint struct {
// without any data being acked.
userTimeout time.Duration
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
@@ -778,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ if e.EndpointState().connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
@@ -822,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
+ switch e.EndpointState() {
+ // Sockets in StateSynRecv state(passive connections) are closed when
+ // the handshake fails or if the listening socket is closed while
+ // handshake was in progress. In such cases the handshake goroutine
+ // is already gone by the time Close is called and we need to cleanup
+ // here.
+ case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked()
- } else {
+ e.setEndpointState(StateClose)
+ case StateError, StateClose:
+ // do nothing.
+ default:
e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
@@ -996,8 +1032,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
}
v, err := e.readLocked()
@@ -1574,6 +1610,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.mu.Unlock()
+ return nil
+
default:
return nil
}
@@ -1798,6 +1843,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
return nil
+ case *tcpip.TCPDeferAcceptOption:
+ e.mu.Lock()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -2025,8 +2076,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// work mutex is available.
if e.workMu.TryLock() {
e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.notifyProtocolGoroutine(notifyTickleWorker)
+ // We need to double check here to make
+ // sure worker has not transitioned the
+ // endpoint out of a connected state
+ // before trying to send a reset.
+ if e.EndpointState().connected() {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
e.mu.Unlock()
e.workMu.Unlock()
} else {
@@ -2138,6 +2195,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
e.isRegistered = true
e.setEndpointState(StateListen)
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with some previously accepted (but not Accepted)
+ // endpoints.
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
@@ -2149,9 +2209,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
-func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+func (e *endpoint) startAcceptedLoop() {
e.mu.Lock()
- e.waiterQueue = waiterQueue
e.workerRunning = true
e.mu.Unlock()
wakerInitDone := make(chan struct{})
@@ -2177,7 +2236,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
default:
return nil, nil, tcpip.ErrWouldBlock
}
-
return n, n.waiterQueue, nil
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 7eb613be5..c9ee5bf06 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- })
+ }, queue)
if err != nil {
return nil, err
}
// Start the protocol goroutine.
- ep.startAcceptedLoop(queue)
+ ep.startAcceptedLoop()
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 958c06fa7..73098d904 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
@@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index df2fb1071..5b2b16afa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -542,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
),
)
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
@@ -5404,12 +5406,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
}
- // Expect InvalidEndpointState errors on a read at this point.
- if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected)
}
- if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
- t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1)
}
if err := ep.Listen(10); err != nil {
@@ -6787,3 +6788,183 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) {
),
)
}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little of the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give sometime for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ iss := seqnum.Value(789)
+ c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ // Send some data to make sure there is some unread
+ // data to trigger a reset on c.Close.
+ irs := c.IRS
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss.Add(1),
+ AckNum: irs.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(irs.Add(1))),
+ checker.AckNum(uint32(iss.Add(5)))))
+
+ // Close in a separate goroutine so that we can trigger
+ // a race with the RST we send below. This should not
+ // panic due to the route being released depeding on
+ // whether Close() sends an active RST or the RST sent
+ // below is processed by the worker first.
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(5),
+ AckNum: c.IRS.Add(5),
+ RcvWnd: 30000,
+ Flags: header.TCPFlagRst,
+ })
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.EP.Close()
+ }()
+
+ wg.Wait()
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 730ac4292..1e9a0dea3 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -1082,7 +1082,11 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
- c.linkEP.GSO = enable
+ if enable {
+ c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ } else {
+ c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ }
}
// MSSWithoutOptions returns the value for the MSS used by the stack when no
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c9cbed8f4..1c6a600b8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -29,9 +29,11 @@ import (
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- tos uint8
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -118,6 +120,13 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -177,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -254,11 +268,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
+ receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
return p.data.ToView(), cm, nil
}
@@ -480,6 +505,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -495,6 +531,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
+ return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -692,6 +735,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
+ }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -703,6 +757,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
}
return false, tcpip.ErrUnknownProtocolOption
@@ -1247,6 +1307,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
switch r.NetProto {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.RemoteAddress
+ packet.packetInfo.NIC = r.NICID()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 259c3072a..8df089d22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index f0ff3fe71..34b7c2360 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
@@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) {
}
}
-func TestTOSV4(t *testing.T) {
+func TestSetTOS(t *testing.T) {
for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) {
const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
}
}
-func TestTOSV6(t *testing.T) {
+func TestSetTClass(t *testing.T) {
for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = testTOS
+ const tClass = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
- testWrite(c, flow, checker.TOS(tos, 0))
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
})
}
}
-func TestReceiveTOSV4(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpointForFlow(flow)
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
- }
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
- want := true
- if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
- c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
- }
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
- got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- if got != want {
- c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
- }
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
- // Verify that the correct received TOS is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- })
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
}
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index ff8b9e91a..6c9ada9c7 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -25,7 +25,6 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
visibility = ["//:sandbox"],
@@ -33,6 +32,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/binary",
"//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 71fd4e155..d2f4403b0 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
diff --git a/runsc/BUILD b/runsc/BUILD
index b35b41d81..757f6d44c 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -19,13 +19,14 @@ go_binary(
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
+ "//runsc/flag",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
],
)
# The runsc-race target is a race-compatible BUILD target. This must be built
-# via: bazel build --features=race //runsc:runsc-race
+# via: bazel build --features=race :runsc-race
#
# This is neccessary because the race feature must apply to all dependencies
# due a bug in gazelle file selection. The pure attribute must be off because
@@ -54,6 +55,7 @@ go_binary(
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
+ "//runsc/flag",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
],
@@ -117,4 +119,5 @@ sh_test(
srcs = ["version_test.sh"],
args = ["$(location :runsc)"],
data = [":runsc"],
+ tags = ["noguitar"],
)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index ae4dd102a..26f68fe3d 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -19,7 +19,6 @@ go_library(
"loader_amd64.go",
"loader_arm64.go",
"network.go",
- "pprof.go",
"strace.go",
"user.go",
],
@@ -91,6 +90,7 @@ go_library(
"//pkg/usermem",
"//runsc/boot/filter",
"//runsc/boot/platforms",
+ "//runsc/boot/pprof",
"//runsc/specutils",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 9c9e94864..17e774e0c 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -142,7 +143,7 @@ func newController(fd int, l *Loader) (*controller, error) {
}
srv.Register(manager)
- if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok {
+ if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok {
net := &Network{
Stack: eps.Stack,
}
@@ -341,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return fmt.Errorf("creating memory file: %v", err)
}
k.SetMemoryFile(mf)
- networkStack := cm.l.k.NetworkStack()
+ networkStack := cm.l.k.RootNetworkNamespace().Stack()
cm.l.k = k
// Set up the restore environment.
@@ -365,9 +366,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
}
if cm.l.conf.ProfileEnable {
- // initializePProf opens /proc/self/maps, so has to be
- // called before installing seccomp filters.
- initializePProf()
+ // pprof.Initialize opens /proc/self/maps, so has to be called before
+ // installing seccomp filters.
+ pprof.Initialize()
}
// Seccomp filters have to be applied before parsing the state file.
diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD
index ce30f6c53..ed18f0047 100644
--- a/runsc/boot/filter/BUILD
+++ b/runsc/boot/filter/BUILD
@@ -8,6 +8,7 @@ go_library(
"config.go",
"config_amd64.go",
"config_arm64.go",
+ "config_profile.go",
"extra_filters.go",
"extra_filters_msan.go",
"extra_filters_race.go",
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 4fb9adca6..a4627905e 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -174,6 +174,18 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_LSEEK: {},
syscall.SYS_MADVISE: {},
syscall.SYS_MINCORE: {},
+ // Used by the Go runtime as a temporarily workaround for a Linux
+ // 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
syscall.SYS_MMAP: []seccomp.Rule{
{
seccomp.AllowAny{},
@@ -217,7 +229,9 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_NANOSLEEP: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
+ syscall.SYS_PREADV: {},
syscall.SYS_PWRITE64: {},
+ syscall.SYS_PWRITEV: {},
syscall.SYS_READ: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
@@ -524,16 +538,3 @@ func controlServerFilters(fd int) seccomp.SyscallRules {
},
}
}
-
-// profileFilters returns extra syscalls made by runtime/pprof package.
-func profileFilters() seccomp.SyscallRules {
- return seccomp.SyscallRules{
- syscall.SYS_OPENAT: []seccomp.Rule{
- {
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
- },
- },
- }
-}
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
new file mode 100644
index 000000000..194952a7b
--- /dev/null
+++ b/runsc/boot/filter/config_profile.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// profileFilters returns extra syscalls made by runtime/pprof package.
+func profileFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_OPENAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ },
+ },
+ }
+}
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 9f0d5d7af..e7ca98134 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -49,6 +49,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -60,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
_ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms.
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
// Include supported socket providers.
@@ -230,11 +232,8 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("enabling strace: %v", err)
}
- // Create an empty network stack because the network namespace may be empty at
- // this point. Netns is configured before Run() is called. Netstack is
- // configured using a control uRPC message. Host network is configured inside
- // Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k, k)
+ // Create root network namespace/stack.
+ netns, err := newRootNetworkNamespace(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -277,7 +276,7 @@ func New(args Args) (*Loader, error) {
FeatureSet: cpuid.HostFeatureSet(),
Timekeeper: tk,
RootUserNamespace: creds.UserNamespace,
- NetworkStack: networkStack,
+ RootNetworkNamespace: netns,
ApplicationCores: uint(args.NumCPU),
Vdso: vdso,
RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace),
@@ -466,7 +465,7 @@ func (l *Loader) run() error {
// Delay host network configuration to this point because network namespace
// is configured after the loader is created and before Run() is called.
log.Debugf("Configuring host network")
- stack := l.k.NetworkStack().(*hostinet.Stack)
+ stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack)
if err := stack.Configure(); err != nil {
return err
}
@@ -485,7 +484,7 @@ func (l *Loader) run() error {
// l.restore is set by the container manager when a restore call is made.
if !l.restore {
if l.conf.ProfileEnable {
- initializePProf()
+ pprof.Initialize()
}
// Finally done with all configuration. Setup filters before user code
@@ -795,16 +794,19 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
return 0, fmt.Errorf("container %q not started", args.ContainerID)
}
+ // TODO(gvisor.dev/issue/1623): Add VFS2 support
+
// Get the container MountNamespace from the Task.
tg.Leader().WithMuLocked(func(t *kernel.Task) {
- // task.MountNamespace() does not take a ref, so we must do so
- // ourselves.
+ // task.MountNamespace() does not take a ref, so we must do so ourselves.
args.MountNamespace = t.MountNamespace()
args.MountNamespace.IncRef()
})
- defer args.MountNamespace.DecRef()
+ if args.MountNamespace != nil {
+ defer args.MountNamespace.DecRef()
+ }
- // Add the HOME enviroment varible if it is not already set.
+ // Add the HOME environment variable if it is not already set.
root := args.MountNamespace.Root()
defer root.DecRef()
ctx := fs.WithRoot(l.k.SupervisorContext(), root)
@@ -905,48 +907,92 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) {
+ // Create an empty network stack because the network namespace may be empty at
+ // this point. Netns is configured before Run() is called. Netstack is
+ // configured using a control uRPC message. Host network is configured inside
+ // Run().
switch conf.Network {
case NetworkHost:
- return hostinet.NewStack(), nil
+ // No network namespacing support for hostinet yet, hence creator is nil.
+ return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
case NetworkNone, NetworkSandbox:
- // NetworkNone sets up loopback using netstack.
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
- s := netstack.Stack{stack.New(stack.Options{
- NetworkProtocols: netProtos,
- TransportProtocols: transProtos,
- Clock: clock,
- Stats: netstack.Metrics,
- HandleLocal: true,
- // Enable raw sockets for users with sufficient
- // privileges.
- RawFactory: raw.EndpointFactory{},
- UniqueID: uniqueID,
- })}
-
- // Enable SACK Recovery.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
- return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ if err != nil {
+ return nil, err
+ }
+ creator := &sandboxNetstackCreator{
+ clock: clock,
+ uniqueID: uniqueID,
}
+ return inet.NewRootNamespace(s, creator), nil
- // Set default TTLs as required by socket/netstack.
- s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
- s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ default:
+ panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ }
- // Enable Receive Buffer Auto-Tuning.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
- }
+}
- s.FillDefaultIPTables()
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
+ // Enable raw sockets for users with sufficient
+ // privileges.
+ RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
+ })}
- return &s, nil
+ // Enable SACK Recovery.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+ return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ }
- default:
- panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
}
+
+ s.FillDefaultIPTables()
+
+ return &s, nil
+}
+
+// sandboxNetstackCreator implements kernel.NetworkStackCreator.
+//
+// +stateify savable
+type sandboxNetstackCreator struct {
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+}
+
+// CreateStack implements kernel.NetworkStackCreator.CreateStack.
+func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Setup loopback.
+ n := &Network{Stack: s.(*netstack.Stack).Stack}
+ nicID := tcpip.NICID(f.uniqueID.UniqueID())
+ link := DefaultLoopbackLink
+ linkEP := loopback.New()
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return nil, err
+ }
+
+ return s, nil
}
// signal sends a signal to one or more processes in a container. If PID is 0,
@@ -994,7 +1040,7 @@ func (l *Loader) signalProcess(cid string, tgid kernel.ThreadID, signo int32) er
execTG, _, err := l.threadGroupFromID(execID{cid: cid, pid: tgid})
if err == nil {
// Send signal directly to the identified process.
- return execTG.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(execTG, &arch.SignalInfo{Signo: signo})
}
// The caller may be signaling a process not started directly via exec.
@@ -1011,7 +1057,7 @@ func (l *Loader) signalProcess(cid string, tgid kernel.ThreadID, signo int32) er
if tg.Leader().ContainerID() != cid {
return fmt.Errorf("process %d is part of a different container: %q", tgid, tg.Leader().ContainerID())
}
- return tg.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
}
func (l *Loader) signalForegrondProcessGroup(cid string, tgid kernel.ThreadID, signo int32) error {
@@ -1029,7 +1075,7 @@ func (l *Loader) signalForegrondProcessGroup(cid string, tgid kernel.ThreadID, s
// No foreground process group has been set. Signal the
// original thread group.
log.Warningf("No foreground process group for container %q and PID %d. Sending signal directly to PID %d.", cid, tgid, tgid)
- return tg.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
}
// Send the signal to all processes in the process group.
var lastErr error
@@ -1037,7 +1083,7 @@ func (l *Loader) signalForegrondProcessGroup(cid string, tgid kernel.ThreadID, s
if tg.ProcessGroup() != pg {
continue
}
- if err := tg.SendSignal(&arch.SignalInfo{Signo: signo}); err != nil {
+ if err := l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo}); err != nil {
lastErr = err
}
}
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 6a8765ec8..bee6ee336 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -17,6 +17,7 @@ package boot
import (
"fmt"
"net"
+ "strings"
"syscall"
"gvisor.dev/gvisor/pkg/log"
@@ -31,6 +32,32 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+var (
+ // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and
+ // "::1/8" on "lo" interface.
+ DefaultLoopbackLink = LoopbackLink{
+ Name: "lo",
+ Addresses: []net.IP{
+ net.IP("\x7f\x00\x00\x01"),
+ net.IPv6loopback,
+ },
+ Routes: []Route{
+ {
+ Destination: net.IPNet{
+ IP: net.IPv4(0x7f, 0, 0, 0),
+ Mask: net.IPv4Mask(0xff, 0, 0, 0),
+ },
+ },
+ {
+ Destination: net.IPNet{
+ IP: net.IPv6loopback,
+ Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
+ },
+ },
+ },
+ }
+)
+
// Network exposes methods that can be used to configure a network stack.
type Network struct {
Stack *stack.Stack
diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD
new file mode 100644
index 000000000..29cb42b2f
--- /dev/null
+++ b/runsc/boot/pprof/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pprof",
+ srcs = ["pprof.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof/pprof.go
index 463362f02..1ded20dee 100644
--- a/runsc/boot/pprof.go
+++ b/runsc/boot/pprof/pprof.go
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+// Package pprof provides a stub to initialize custom profilers.
+package pprof
-func initializePProf() {
+// Initialize will be called at boot for initializing custom profilers.
+func Initialize() {
}
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 09aa46434..d0bb4613a 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -31,6 +31,7 @@ go_library(
"spec.go",
"start.go",
"state.go",
+ "statefile.go",
"syscalls.go",
"wait.go",
],
@@ -43,6 +44,8 @@ go_library(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/state",
+ "//pkg/state/statefile",
"//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
@@ -50,6 +53,7 @@ go_library(
"//runsc/boot/platforms",
"//runsc/console",
"//runsc/container",
+ "//runsc/flag",
"//runsc/fsgofer",
"//runsc/fsgofer/filter",
"//runsc/specutils",
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index b40fded5b..0f3da69a0 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -21,12 +21,12 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index d8b3a8573..8a29e521e 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -20,11 +20,11 @@ import (
"path/filepath"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
index 1815c93b9..910e97577 100644
--- a/runsc/cmd/create.go
+++ b/runsc/cmd/create.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index f37415810..79965460e 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -22,12 +22,12 @@ import (
"syscall"
"time"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Debug implements subcommands.Command for the "debug" command.
diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go
index 30d8164b1..0e4863f50 100644
--- a/runsc/cmd/delete.go
+++ b/runsc/cmd/delete.go
@@ -19,11 +19,11 @@ import (
"fmt"
"os"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Delete implements subcommands.Command for the "delete" command.
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 9a8a49054..b184bd402 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -27,12 +27,12 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go
index 3972e9224..51f6a98ed 100644
--- a/runsc/cmd/events.go
+++ b/runsc/cmd/events.go
@@ -20,11 +20,11 @@ import (
"os"
"time"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Events implements subcommands.Command for the "events" command.
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index d1e99243b..d9a94903e 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -27,7 +27,6 @@ import (
"syscall"
"time"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/console"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 7df7995f0..6e06f3c0f 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -23,7 +23,6 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
@@ -32,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/fsgofer"
"gvisor.dev/gvisor/runsc/fsgofer/filter"
"gvisor.dev/gvisor/runsc/specutils"
diff --git a/runsc/cmd/help.go b/runsc/cmd/help.go
index 930e8454f..c7d210140 100644
--- a/runsc/cmd/help.go
+++ b/runsc/cmd/help.go
@@ -18,8 +18,8 @@ import (
"context"
"fmt"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
)
// NewHelp returns a help command for the given commander.
diff --git a/runsc/cmd/install.go b/runsc/cmd/install.go
index 441c1db0d..2e223e3be 100644
--- a/runsc/cmd/install.go
+++ b/runsc/cmd/install.go
@@ -23,8 +23,8 @@ import (
"os"
"path"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Install implements subcommands.Command.
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
index 6c1f197a6..8282ea0e0 100644
--- a/runsc/cmd/kill.go
+++ b/runsc/cmd/kill.go
@@ -21,11 +21,11 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Kill implements subcommands.Command for the "kill" command.
diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go
index dd2d99a6b..d8d906fe3 100644
--- a/runsc/cmd/list.go
+++ b/runsc/cmd/list.go
@@ -22,11 +22,11 @@ import (
"text/tabwriter"
"time"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// List implements subcommands.Command for the "list" command for the "list" command.
diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go
index 9c0e92001..6f95a9837 100644
--- a/runsc/cmd/pause.go
+++ b/runsc/cmd/pause.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Pause implements subcommands.Command for the "pause" command.
diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go
index 45c644f3f..7fb8041af 100644
--- a/runsc/cmd/ps.go
+++ b/runsc/cmd/ps.go
@@ -18,11 +18,11 @@ import (
"context"
"fmt"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// PS implements subcommands.Command for the "ps" command.
diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go
index 7be60cd7d..72584b326 100644
--- a/runsc/cmd/restore.go
+++ b/runsc/cmd/restore.go
@@ -19,10 +19,10 @@ import (
"path/filepath"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go
index b2df5c640..61a55a554 100644
--- a/runsc/cmd/resume.go
+++ b/runsc/cmd/resume.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Resume implements subcommands.Command for the "resume" command.
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index 33f4bc12b..cf41581ad 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -18,10 +18,10 @@ import (
"context"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go
index 344da13ba..8e2b36e85 100644
--- a/runsc/cmd/spec.go
+++ b/runsc/cmd/spec.go
@@ -20,8 +20,8 @@ import (
"os"
"path/filepath"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
)
var specTemplate = []byte(`{
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index 5e9bc53ab..0205fd9f7 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Start implements subcommands.Command for the "start" command.
diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go
index e9f41cbd8..cf2413deb 100644
--- a/runsc/cmd/state.go
+++ b/runsc/cmd/state.go
@@ -19,11 +19,11 @@ import (
"encoding/json"
"os"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// State implements subcommands.Command for the "state" command.
diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go
new file mode 100644
index 000000000..e6f1907da
--- /dev/null
+++ b/runsc/cmd/statefile.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/statefile"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Statefile implements subcommands.Command for the "statefile" command.
+type Statefile struct {
+ list bool
+ get string
+ key string
+ output string
+ html bool
+}
+
+// Name implements subcommands.Command.
+func (*Statefile) Name() string {
+ return "state"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Statefile) Synopsis() string {
+ return "shows information about a statefile"
+}
+
+// Usage implements subcommands.Command.
+func (*Statefile) Usage() string {
+ return `statefile [flags] <statefile>`
+}
+
+// SetFlags implements subcommands.Command.
+func (s *Statefile) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&s.list, "list", false, "lists the metdata in the statefile.")
+ f.StringVar(&s.get, "get", "", "extracts the given metadata key.")
+ f.StringVar(&s.key, "key", "", "the integrity key for the file.")
+ f.StringVar(&s.output, "output", "", "target to write the result.")
+ f.BoolVar(&s.html, "html", false, "outputs in HTML format.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Check arguments.
+ if s.list && s.get != "" {
+ Fatalf("error: can't specify -list and -get simultaneously.")
+ }
+
+ // Setup output.
+ var output = os.Stdout // Default.
+ if s.output != "" {
+ f, err := os.OpenFile(s.output, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
+ if err != nil {
+ Fatalf("error opening output: %v", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ Fatalf("error flushing output: %v", err)
+ }
+ }()
+ output = f
+ }
+
+ // Open the file.
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ input, err := os.Open(f.Arg(0))
+ if err != nil {
+ Fatalf("error opening input: %v\n", err)
+ }
+
+ if s.html {
+ fmt.Fprintf(output, "<html><body>\n")
+ defer fmt.Fprintf(output, "</body></html>\n")
+ }
+
+ // Dump the full file?
+ if !s.list && s.get == "" {
+ var key []byte
+ if s.key != "" {
+ key = []byte(s.key)
+ }
+ rc, _, err := statefile.NewReader(input, key)
+ if err != nil {
+ Fatalf("error parsing statefile: %v", err)
+ }
+ if err := state.PrettyPrint(output, rc, s.html); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ return subcommands.ExitSuccess
+ }
+
+ // Load just the metadata.
+ metadata, err := statefile.MetadataUnsafe(input)
+ if err != nil {
+ Fatalf("error reading metadata: %v", err)
+ }
+
+ // Is it a single key?
+ if s.get != "" {
+ val, ok := metadata[s.get]
+ if !ok {
+ Fatalf("metadata key %s: not found", s.get)
+ }
+ fmt.Fprintf(output, "%s\n", val)
+ return subcommands.ExitSuccess
+ }
+
+ // List all keys.
+ if s.html {
+ fmt.Fprintf(output, " <ul>\n")
+ defer fmt.Fprintf(output, " </ul>\n")
+ }
+ for key := range metadata {
+ if s.html {
+ fmt.Fprintf(output, " <li>%s</li>\n", key)
+ } else {
+ fmt.Fprintf(output, "%s\n", key)
+ }
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/syscalls.go b/runsc/cmd/syscalls.go
index fb6c1ab29..7072547be 100644
--- a/runsc/cmd/syscalls.go
+++ b/runsc/cmd/syscalls.go
@@ -25,9 +25,9 @@ import (
"strconv"
"text/tabwriter"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Syscalls implements subcommands.Command for the "syscalls" command.
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
index 046489687..29c0a15f0 100644
--- a/runsc/cmd/wait.go
+++ b/runsc/cmd/wait.go
@@ -20,10 +20,10 @@ import (
"os"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
const (
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index e21431e4c..0aaeea3a8 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -30,7 +30,7 @@ go_library(
go_test(
name = "container_test",
- size = "medium",
+ size = "large",
srcs = [
"console_test.go",
"container_test.go",
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 060b63bf3..c2518d52b 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -196,7 +196,10 @@ func TestJobControlSignalExec(t *testing.T) {
defer ptyMaster.Close()
defer ptySlave.Close()
- // Exec bash and attach a terminal.
+ // Exec bash and attach a terminal. Note that occasionally /bin/sh
+ // may be a different shell or have a different configuration (such
+ // as disabling interactive mode and job control). Since we want to
+ // explicitly test interactive mode, use /bin/bash. See b/116981926.
execArgs := &control.ExecArgs{
Filename: "/bin/bash",
// Don't let bash execute from profile or rc files, otherwise
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index b54d8f712..bdd65b498 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -71,6 +71,7 @@ func waitForProcessCount(cont *Container, want int) error {
return &backoff.PermanentError{Err: err}
}
if got := len(pss); got != want {
+ log.Infof("Waiting for process count to reach %d. Current: %d", want, got)
return fmt.Errorf("wrong process count, got: %d, want: %d", got, want)
}
return nil
@@ -163,7 +164,7 @@ func createWriteableOutputFile(path string) (*os.File, error) {
return outputFile, nil
}
-func waitForFile(f *os.File) error {
+func waitForFileNotEmpty(f *os.File) error {
op := func() error {
fi, err := f.Stat()
if err != nil {
@@ -178,6 +179,17 @@ func waitForFile(f *os.File) error {
return testutil.Poll(op, 30*time.Second)
}
+func waitForFileExist(path string) error {
+ op := func() error {
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ return err
+ }
+ return nil
+ }
+
+ return testutil.Poll(op, 30*time.Second)
+}
+
// readOutputNum reads a file at given filepath and returns the int at the
// requested position.
func readOutputNum(file string, position int) (int, error) {
@@ -187,7 +199,7 @@ func readOutputNum(file string, position int) (int, error) {
}
// Ensure that there is content in output file.
- if err := waitForFile(f); err != nil {
+ if err := waitForFileNotEmpty(f); err != nil {
return 0, fmt.Errorf("error waiting for output file: %v", err)
}
@@ -801,7 +813,7 @@ func TestCheckpointRestore(t *testing.T) {
defer file.Close()
// Wait until application has ran.
- if err := waitForFile(outputFile); err != nil {
+ if err := waitForFileNotEmpty(outputFile); err != nil {
t.Fatalf("Failed to wait for output file: %v", err)
}
@@ -843,7 +855,7 @@ func TestCheckpointRestore(t *testing.T) {
}
// Wait until application has ran.
- if err := waitForFile(outputFile2); err != nil {
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
t.Fatalf("Failed to wait for output file: %v", err)
}
@@ -887,7 +899,7 @@ func TestCheckpointRestore(t *testing.T) {
}
// Wait until application has ran.
- if err := waitForFile(outputFile3); err != nil {
+ if err := waitForFileNotEmpty(outputFile3); err != nil {
t.Fatalf("Failed to wait for output file: %v", err)
}
@@ -981,7 +993,7 @@ func TestUnixDomainSockets(t *testing.T) {
defer os.RemoveAll(imagePath)
// Wait until application has ran.
- if err := waitForFile(outputFile); err != nil {
+ if err := waitForFileNotEmpty(outputFile); err != nil {
t.Fatalf("Failed to wait for output file: %v", err)
}
@@ -1023,7 +1035,7 @@ func TestUnixDomainSockets(t *testing.T) {
}
// Wait until application has ran.
- if err := waitForFile(outputFile2); err != nil {
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
t.Fatalf("Failed to wait for output file: %v", err)
}
@@ -1042,126 +1054,84 @@ func TestUnixDomainSockets(t *testing.T) {
}
// TestPauseResume tests that we can successfully pause and resume a container.
-// It checks starts running sleep and executes another sleep. It pauses and checks
-// that both processes are still running: sleep will be paused and still exist.
-// It will then unpause and confirm that both processes are running. Then it will
-// wait until one sleep completes and check to make sure the other is running.
+// The container will keep touching a file to indicate it's running. The test
+// pauses the container, removes the file, and checks that it doesn't get
+// recreated. Then it resumes the container, verify that the file gets created
+// again.
func TestPauseResume(t *testing.T) {
for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
- const uid = 343
- spec := testutil.NewSpecWithArgs("sleep", "20")
-
- lock, err := ioutil.TempFile(testutil.TmpDir(), "lock")
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer lock.Close()
+ t.Run(fmt.Sprintf("conf: %+v", conf), func(t *testing.T) {
+ t.Logf("Running test with conf: %+v", conf)
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
-
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
-
- // expectedPL lists the expected process state of the container.
- expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{1},
- },
- {
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "bash",
- Threads: []kernel.ThreadID{2},
- },
- }
-
- script := fmt.Sprintf("while [[ -f %q ]]; do sleep 0.1; done", lock.Name())
- execArgs := &control.ExecArgs{
- Filename: "/bin/bash",
- Argv: []string{"bash", "-c", script},
- WorkingDirectory: "/",
- KUID: uid,
- }
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock")
+ if err != nil {
+ t.Fatalf("error creating temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
- // First, start running exec.
- _, err = cont.Execute(execArgs)
- if err != nil {
- t.Fatalf("error executing: %v", err)
- }
+ running := path.Join(tmpDir, "running")
+ script := fmt.Sprintf("while [[ true ]]; do touch %q; sleep 0.1; done", running)
+ spec := testutil.NewSpecWithArgs("/bin/bash", "-c", script)
- // Verify that "sleep 5" is running.
- if err := waitForProcessList(cont, expectedPL); err != nil {
- t.Fatal(err)
- }
+ rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer os.RemoveAll(rootDir)
+ defer os.RemoveAll(bundleDir)
- // Pause the running container.
- if err := cont.Pause(); err != nil {
- t.Errorf("error pausing container: %v", err)
- }
- if got, want := cont.Status, Paused; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.UniqueContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- if err := os.Remove(lock.Name()); err != nil {
- t.Fatalf("os.Remove(lock) failed: %v", err)
- }
- // Script loops and sleeps for 100ms. Give a bit a time for it to exit in
- // case pause didn't work.
- time.Sleep(200 * time.Millisecond)
+ // Wait until container starts running, observed by the existence of running
+ // file.
+ if err := waitForFileExist(running); err != nil {
+ t.Errorf("error waiting for container to start: %v", err)
+ }
- // Verify that the two processes still exist.
- if err := getAndCheckProcLists(cont, expectedPL); err != nil {
- t.Fatal(err)
- }
+ // Pause the running container.
+ if err := cont.Pause(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Paused; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
- // Resume the running container.
- if err := cont.Resume(); err != nil {
- t.Errorf("error pausing container: %v", err)
- }
- if got, want := cont.Status, Running; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ if err := os.Remove(running); err != nil {
+ t.Fatalf("os.Remove(%q) failed: %v", running, err)
+ }
+ // Script touches the file every 100ms. Give a bit a time for it to run to
+ // catch the case that pause didn't work.
+ time.Sleep(200 * time.Millisecond)
+ if _, err := os.Stat(running); !os.IsNotExist(err) {
+ t.Fatalf("container did not pause: file exist check: %v", err)
+ }
- expectedPL2 := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- Threads: []kernel.ThreadID{1},
- },
- }
+ // Resume the running container.
+ if err := cont.Resume(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
- // Verify that deleting the file triggered the process to exit.
- if err := waitForProcessList(cont, expectedPL2); err != nil {
- t.Fatal(err)
- }
+ // Verify that the file is once again created by container.
+ if err := waitForFileExist(running); err != nil {
+ t.Fatalf("error resuming container: file exist check: %v", err)
+ }
+ })
}
}
diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD
index e200bafd9..0defbd9fc 100644
--- a/runsc/container/test_app/BUILD
+++ b/runsc/container/test_app/BUILD
@@ -13,6 +13,7 @@ go_binary(
visibility = ["//runsc/container:__pkg__"],
deps = [
"//pkg/unet",
+ "//runsc/flag",
"//runsc/testutil",
"@com_github_google_subcommands//:go_default_library",
"@com_github_kr_pty//:go_default_library",
diff --git a/runsc/container/test_app/fds.go b/runsc/container/test_app/fds.go
index a90cc1662..2a146a2c3 100644
--- a/runsc/container/test_app/fds.go
+++ b/runsc/container/test_app/fds.go
@@ -21,9 +21,9 @@ import (
"os"
"time"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/testutil"
)
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index a1c8a741a..01c47c79f 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -30,9 +30,9 @@ import (
sys "syscall"
"time"
- "flag"
"github.com/google/subcommands"
"github.com/kr/pty"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/testutil"
)
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
index 9b6346ca2..1ff5e8cc3 100644
--- a/runsc/dockerutil/dockerutil.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -143,8 +143,11 @@ func PrepareFiles(names ...string) (string, error) {
return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err)
}
for _, name := range names {
- src := getLocalPath(name)
- dst := path.Join(dir, name)
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ return "", fmt.Errorf("testutil.Preparefiles(%q) failed: %v", name, err)
+ }
+ dst := path.Join(dir, path.Base(name))
if err := testutil.Copy(src, dst); err != nil {
return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
}
@@ -152,10 +155,6 @@ func PrepareFiles(names ...string) (string, error) {
return dir, nil
}
-func getLocalPath(file string) string {
- return path.Join(".", file)
-}
-
// do executes docker command.
func do(args ...string) (string, error) {
log.Printf("Running: docker %s\n", args)
diff --git a/runsc/flag/BUILD b/runsc/flag/BUILD
new file mode 100644
index 000000000..5cb7604a8
--- /dev/null
+++ b/runsc/flag/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "flag",
+ srcs = ["flag.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/fspath/builder_unsafe.go b/runsc/flag/flag.go
index 75606808d..0ca4829d7 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/runsc/flag/flag.go
@@ -12,16 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+package flag
import (
- "unsafe"
+ "flag"
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+type FlagSet = flag.FlagSet
+
+var (
+ NewFlagSet = flag.NewFlagSet
+ String = flag.String
+ Bool = flag.Bool
+ Int = flag.Int
+ Uint = flag.Uint
+ CommandLine = flag.CommandLine
+ Parse = flag.Parse
+)
+
+const ContinueOnError = flag.ContinueOnError
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index a1792330f..1dce36965 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -128,6 +128,18 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_MADVISE: {},
unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
syscall.SYS_MKDIRAT: {},
+ // Used by the Go runtime as a temporarily workaround for a Linux
+ // 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
syscall.SYS_MMAP: []seccomp.Rule{
{
seccomp.AllowAny{},
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 4d84ad999..cadd83273 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -768,12 +768,22 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
}
// TODO(b/127675828): support getxattr.
-func (l *localFile) GetXattr(name string, size uint64) (string, error) {
+func (*localFile) GetXattr(string, uint64) (string, error) {
return "", syscall.EOPNOTSUPP
}
// TODO(b/127675828): support setxattr.
-func (l *localFile) SetXattr(name, value string, flags uint32) error {
+func (*localFile) SetXattr(string, string, uint32) error {
+ return syscall.EOPNOTSUPP
+}
+
+// TODO(b/148303075): support listxattr.
+func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
+ return nil, syscall.EOPNOTSUPP
+}
+
+// TODO(b/148303075): support removexattr.
+func (*localFile) RemoveXattr(string) error {
return syscall.EOPNOTSUPP
}
@@ -790,7 +800,7 @@ func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error
}
// Rename implements p9.File; this should never be called.
-func (l *localFile) Rename(p9.File, string) error {
+func (*localFile) Rename(p9.File, string) error {
panic("rename called directly")
}
diff --git a/runsc/main.go b/runsc/main.go
index c2b0d9a9e..af73bed97 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -28,14 +28,13 @@ import (
"syscall"
"time"
- "flag"
-
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cmd"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -117,8 +116,8 @@ func main() {
subcommands.Register(new(cmd.Resume), "")
subcommands.Register(new(cmd.Run), "")
subcommands.Register(new(cmd.Spec), "")
- subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.State), "")
+ subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.Wait), "")
// Register internal commands with the internal group name. This causes
@@ -128,6 +127,7 @@ func main() {
subcommands.Register(new(cmd.Boot), internalGroup)
subcommands.Register(new(cmd.Debug), internalGroup)
subcommands.Register(new(cmd.Gofer), internalGroup)
+ subcommands.Register(new(cmd.Statefile), internalGroup)
// All subcommands must be registered before flag parsing.
flag.Parse()
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index ff48f5646..bc093fba5 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -21,7 +21,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "strings"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -75,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
}
func createDefaultLoopbackInterface(conn *urpc.Client) error {
- link := boot.LoopbackLink{
- Name: "lo",
- Addresses: []net.IP{
- net.IP("\x7f\x00\x00\x01"),
- net.IPv6loopback,
- },
- Routes: []boot.Route{
- {
- Destination: net.IPNet{
-
- IP: net.IPv4(0x7f, 0, 0, 0),
- Mask: net.IPv4Mask(0xff, 0, 0, 0),
- },
- },
- {
- Destination: net.IPNet{
- IP: net.IPv6loopback,
- Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
- },
- },
- },
- }
if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{
- LoopbackLinks: []boot.LoopbackLink{link},
+ LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink},
}, nil); err != nil {
return fmt.Errorf("creating loopback link and routes: %v", err)
}
@@ -174,13 +151,13 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err)
}
- // We build our own loopback devices.
+ // We build our own loopback device.
if iface.Flags&net.FlagLoopback != 0 {
- links, err := loopbackLinks(iface, allAddrs)
+ link, err := loopbackLink(iface, allAddrs)
if err != nil {
- return fmt.Errorf("getting loopback routes and links for iface %q: %v", iface.Name, err)
+ return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err)
}
- args.LoopbackLinks = append(args.LoopbackLinks, links...)
+ args.LoopbackLinks = append(args.LoopbackLinks, link)
continue
}
@@ -339,25 +316,25 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
return &socketEntry{deviceFile, gsoMaxSize}, nil
}
-// loopbackLinks collects the links for a loopback interface.
-func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) {
- var links []boot.LoopbackLink
+// loopbackLink returns the link with addresses and routes for a loopback
+// interface.
+func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) {
+ link := boot.LoopbackLink{
+ Name: iface.Name,
+ }
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
- return nil, fmt.Errorf("address is not IPNet: %+v", addr)
+ return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr)
}
dst := *ipNet
dst.IP = dst.IP.Mask(dst.Mask)
- links = append(links, boot.LoopbackLink{
- Name: iface.Name,
- Addresses: []net.IP{ipNet.IP},
- Routes: []boot.Route{{
- Destination: dst,
- }},
+ link.Addresses = append(link.Addresses, ipNet.IP)
+ link.Routes = append(link.Routes, boot.Route{
+ Destination: dst,
})
}
- return links, nil
+ return link, nil
}
// routesForIface iterates over all routes for the given interface and converts
diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD
index f845120b0..945405303 100644
--- a/runsc/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -5,7 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = 1,
- srcs = ["testutil.go"],
+ srcs = [
+ "testutil.go",
+ "testutil_runfiles.go",
+ ],
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index fb22eae39..92d677e71 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -79,65 +79,16 @@ func ConfigureExePath() error {
return nil
}
-// FindFile searchs for a file inside the test run environment. It returns the
-// full path to the file. It fails if none or more than one file is found.
-func FindFile(path string) (string, error) {
- wd, err := os.Getwd()
- if err != nil {
- return "", err
- }
-
- // The test root is demarcated by a path element called "__main__". Search for
- // it backwards from the working directory.
- root := wd
- for {
- dir, name := filepath.Split(root)
- if name == "__main__" {
- break
- }
- if len(dir) == 0 {
- return "", fmt.Errorf("directory __main__ not found in %q", wd)
- }
- // Remove ending slash to loop around.
- root = dir[:len(dir)-1]
- }
-
- // Annoyingly, bazel adds the build type to the directory path for go
- // binaries, but not for c++ binaries. We use two different patterns to
- // to find our file.
- patterns := []string{
- // Try the obvious path first.
- filepath.Join(root, path),
- // If it was a go binary, use a wildcard to match the build
- // type. The pattern is: /test-path/__main__/directories/*/file.
- filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
- }
-
- for _, p := range patterns {
- matches, err := filepath.Glob(p)
- if err != nil {
- // "The only possible returned error is ErrBadPattern,
- // when pattern is malformed." -godoc
- return "", fmt.Errorf("error globbing %q: %v", p, err)
- }
- switch len(matches) {
- case 0:
- // Try the next pattern.
- case 1:
- // We found it.
- return matches[0], nil
- default:
- return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
- }
- }
- return "", fmt.Errorf("file %q not found", path)
-}
-
// TestConfig returns the default configuration to use in tests. Note that
// 'RootDir' must be set by caller if required.
func TestConfig() *boot.Config {
+ logDir := ""
+ if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
+ logDir = dir + "/"
+ }
return &boot.Config{
Debug: true,
+ DebugLog: logDir,
LogFormat: "text",
DebugLogFormat: "text",
AlsoLogToStderr: true,
@@ -168,6 +119,13 @@ func NewSpecWithArgs(args ...string) *specs.Spec {
Capabilities: specutils.AllCapabilities(),
},
Mounts: []specs.Mount{
+ // Hide the host /etc to avoid any side-effects.
+ // For example, bash reads /etc/passwd and if it is
+ // very big, tests can fail by timeout.
+ {
+ Type: "tmpfs",
+ Destination: "/etc",
+ },
// Root is readonly, but many tests want to write to tmpdir.
// This creates a writable mount inside the root. Also, when tmpdir points
// to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs
@@ -434,43 +392,40 @@ func IsStatic(filename string) (bool, error) {
return true, nil
}
-// TestBoundsForShard calculates the beginning and end indices for the test
-// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The
-// returned ints are the beginning (inclusive) and end (exclusive) of the
-// subslice corresponding to the shard. If either of the env vars are not
-// present, then the function will return bounds that include all tests. If
-// there are more shards than there are tests, then the returned list may be
-// empty.
-func TestBoundsForShard(numTests int) (int, int, error) {
+// TestIndicesForShard returns indices for this test shard based on the
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+//
+// If either of the env vars are not present, then the function will return all
+// tests. If there are more shards than there are tests, then the returned list
+// may be empty.
+func TestIndicesForShard(numTests int) ([]int, error) {
var (
- begin = 0
- end = numTests
+ shardIndex = 0
+ shardTotal = 1
)
- indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
- if indexStr == "" || totalStr == "" {
- return begin, end, nil
- }
- // Parse index and total to ints.
- shardIndex, err := strconv.Atoi(indexStr)
- if err != nil {
- return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
- }
- shardTotal, err := strconv.Atoi(totalStr)
- if err != nil {
- return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr != "" && totalStr != "" {
+ // Parse index and total to ints.
+ var err error
+ shardIndex, err = strconv.Atoi(indexStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err = strconv.Atoi(totalStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
}
// Calculate!
- shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal)))
- begin = shardIndex * shardSize
- end = ((shardIndex + 1) * shardSize)
- if begin > numTests {
- // Nothing to run.
- return 0, 0, nil
- }
- if end > numTests {
- end = numTests
+ var indices []int
+ numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ for i := 0; i < numBlocks; i++ {
+ pick := i*shardTotal + shardIndex
+ if pick < numTests {
+ indices = append(indices, pick)
+ }
}
- return begin, end, nil
+ return indices, nil
}
diff --git a/runsc/testutil/testutil_runfiles.go b/runsc/testutil/testutil_runfiles.go
new file mode 100644
index 000000000..ece9ea9a1
--- /dev/null
+++ b/runsc/testutil/testutil_runfiles.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// FindFile searchs for a file inside the test run environment. It returns the
+// full path to the file. It fails if none or more than one file is found.
+func FindFile(path string) (string, error) {
+ wd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // The test root is demarcated by a path element called "__main__". Search for
+ // it backwards from the working directory.
+ root := wd
+ for {
+ dir, name := filepath.Split(root)
+ if name == "__main__" {
+ break
+ }
+ if len(dir) == 0 {
+ return "", fmt.Errorf("directory __main__ not found in %q", wd)
+ }
+ // Remove ending slash to loop around.
+ root = dir[:len(dir)-1]
+ }
+
+ // Annoyingly, bazel adds the build type to the directory path for go
+ // binaries, but not for c++ binaries. We use two different patterns to
+ // to find our file.
+ patterns := []string{
+ // Try the obvious path first.
+ filepath.Join(root, path),
+ // If it was a go binary, use a wildcard to match the build
+ // type. The pattern is: /test-path/__main__/directories/*/file.
+ filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+ }
+
+ for _, p := range patterns {
+ matches, err := filepath.Glob(p)
+ if err != nil {
+ // "The only possible returned error is ErrBadPattern,
+ // when pattern is malformed." -godoc
+ return "", fmt.Errorf("error globbing %q: %v", p, err)
+ }
+ switch len(matches) {
+ case 0:
+ // Try the next pattern.
+ case 1:
+ // We found it.
+ return matches[0], nil
+ default:
+ return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
+ }
+ }
+ return "", fmt.Errorf("file %q not found", path)
+}
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
new file mode 100644
index 000000000..a0317db02
--- /dev/null
+++ b/scripts/benchmark.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run in the root of the repo.
+cd "$(dirname "$0")"
+
+KEY_PATH=${KEY_PATH:-"${KOKORO_KEYSTORE_DIR}/${KOKORO_SERVICE_ACCOUNT}"}
+
+gcloud auth activate-service-account --key-file "${KEY_PATH}"
+
+gcloud compute instances list
+
diff --git a/scripts/common.sh b/scripts/common.sh
index cd91b9f8e..3ca699e4a 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -16,7 +16,17 @@
set -xeou pipefail
-source $(dirname $0)/common_build.sh
+# Get the path to the directory this script lives in.
+# If this script is being called with `source`, $0 will be the path of the
+# *sourcing* script, so we can't use `dirname $0` to find scripts in this
+# directory.
+if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then
+ declare -r script_dir="$(dirname "$BASH_SOURCE")"
+else
+ declare -r script_dir="$(dirname "$0")"
+fi
+
+source "${script_dir}/common_build.sh"
# Ensure it attempts to collect logs in all cases.
trap collect_logs EXIT
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
index a473a88a4..3be0bb21c 100755
--- a/scripts/common_build.sh
+++ b/scripts/common_build.sh
@@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh latest) || which bazel
+which bazel
bazel version
# Switch into the workspace; only necessary if run with kokoro.
@@ -26,27 +25,30 @@ elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
fi
# Set the standard bazel flags.
-declare -r BAZEL_FLAGS=(
+declare -a BAZEL_FLAGS=(
"--show_timestamps"
"--test_output=errors"
"--keep_going"
"--verbose_failures=true"
)
if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
- declare -r BAZEL_RBE_AUTH_FLAGS=(
+ BAZEL_FLAGS+=(
"--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
+ "--config=remote"
)
- declare -r BAZEL_RBE_FLAGS=("--config=remote")
fi
+declare -r BAZEL_FLAGS
# Wrap bazel.
function build() {
- bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
- tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
+ bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \
+ | tee /dev/fd/2 \
+ | grep -E '^ bazel-bin/' \
+ | awk '{ print $1; }'
}
function test() {
- bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
+ bazel test "${BAZEL_FLAGS[@]}" "$@"
}
function run() {
@@ -68,7 +70,9 @@ function collect_logs() {
for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
junitparser merge `find $d -name test.xml` $d/test.xml
cat $d/shard_*_of_*/test.log > $d/test.log
- ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip
+ if ls -l $d/shard_*_of_*/test.outputs/outputs.zip 2>/dev/null; then
+ zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs/outputs.zip
+ fi
done
find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
# Move test logs to Kokoro directory. tar is used to conveniently perform
@@ -88,12 +92,21 @@ function collect_logs() {
echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
fi
- tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ time tar \
+ --verbose \
+ --create \
+ --gzip \
+ --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
+ --directory "${RUNSC_LOGS_DIR}" \
+ .
fi
fi
fi
}
function find_branch_name() {
- git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+ git branch --show-current \
+ || git rev-parse HEAD \
+ || bazel info workspace \
+ | xargs basename
}
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
index c47cbd675..3069d8628 100755
--- a/scripts/iptables_tests.sh
+++ b/scripts/iptables_tests.sh
@@ -19,9 +19,9 @@ source $(dirname $0)/common.sh
install_runsc_for_test iptables
# Build the docker image for the test.
-run //test/iptables/runner --norun
+run //test/iptables/runner-image --norun
# TODO(gvisor.dev/issue/170): Also test this on runsc once iptables are better
# supported
test //test/iptables:iptables_test "--test_arg=--runtime=runc" \
- "--test_arg=--image=bazel/test/iptables/runner:runner"
+ "--test_arg=--image=bazel/test/iptables/runner:runner-image"
diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh
new file mode 100755
index 000000000..fc6bef79c
--- /dev/null
+++ b/scripts/packetdrill_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+install_runsc_for_test runsc-d
+test_runsc $(bazel query "attr(tags, manual, tests(//test/packetdrill/...))")
diff --git a/scripts/release.sh b/scripts/release.sh
index 091abf87f..e14ba04a7 100755
--- a/scripts/release.sh
+++ b/scripts/release.sh
@@ -25,6 +25,14 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then
echo "No KOKORO_RELEASE_TAG provided." >&2
exit 1
fi
+if ! [[ -v KOKORO_RELNOTES ]]; then
+ echo "No KOKORO_RELNOTES provided." >&2
+ exit 1
+fi
+if ! [[ -r "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}" ]]; then
+ echo "The file '${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}' is not readable." >&2
+ exit 1
+fi
# Unless an explicit releaser is provided, use the bot e-mail.
declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
@@ -46,4 +54,7 @@ EOF
fi
# Run the release tool, which pushes to the origin repository.
-tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
+tools/tag_release.sh \
+ "${KOKORO_RELEASE_COMMIT}" \
+ "${KOKORO_RELEASE_TAG}" \
+ "${KOKORO_ARTIFACTS_DIR}/${KOKORO_RELNOTES}"
diff --git a/test/image/image_test.go b/test/image/image_test.go
index d0dcb1861..0a1e19d6f 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -107,7 +107,7 @@ func TestHttpd(t *testing.T) {
}
d := dockerutil.MakeDocker("http-test")
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -139,7 +139,7 @@ func TestNginx(t *testing.T) {
}
d := dockerutil.MakeDocker("net-test")
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -183,7 +183,7 @@ func TestMysql(t *testing.T) {
}
client := dockerutil.MakeDocker("mysql-client-test")
- dir, err := dockerutil.PrepareFiles("mysql.sql")
+ dir, err := dockerutil.PrepareFiles("test/image/mysql.sql")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -283,7 +283,7 @@ func TestRuby(t *testing.T) {
}
d := dockerutil.MakeDocker("ruby-test")
- dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh")
+ dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
diff --git a/test/iptables/README.md b/test/iptables/README.md
index 9f8e34420..cc8a2fcac 100644
--- a/test/iptables/README.md
+++ b/test/iptables/README.md
@@ -2,6 +2,9 @@
iptables tests are run via `scripts/iptables_test.sh`.
+iptables requires raw socket support, so you must add the `--net-raw=true` flag
+to `/etc/docker/daemon.json` in order to use it.
+
## Test Structure
Each test implements `TestCase`, providing (1) a function to run inside the
@@ -25,10 +28,17 @@ Your test is now runnable with bazel!
## Run individual tests
-Build the testing Docker container:
+Build and install `runsc`. Re-run this when you modify gVisor:
+
+```bash
+$ bazel build //runsc && sudo cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc $(which runsc)
+```
+
+Build the testing Docker container. Re-run this when you modify the test code in
+this directory:
```bash
-$ bazel run //test/iptables/runner -- --norun
+$ bazel run //test/iptables/runner:runner-image -- --norun
```
Run an individual test via:
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index fd02ff2ff..b2fb6401a 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -15,6 +15,7 @@
package iptables
import (
+ "errors"
"fmt"
"net"
"time"
@@ -25,6 +26,7 @@ const (
acceptPort = 2402
sendloopDuration = 2 * time.Second
network = "udp4"
+ chainName = "foochain"
)
func init() {
@@ -35,6 +37,16 @@ func init() {
RegisterTestCase(FilterInputDropTCPSrcPort{})
RegisterTestCase(FilterInputDropUDPPort{})
RegisterTestCase(FilterInputDropUDP{})
+ RegisterTestCase(FilterInputCreateUserChain{})
+ RegisterTestCase(FilterInputDefaultPolicyAccept{})
+ RegisterTestCase(FilterInputDefaultPolicyDrop{})
+ RegisterTestCase(FilterInputReturnUnderflow{})
+ RegisterTestCase(FilterInputSerializeJump{})
+ RegisterTestCase(FilterInputJumpBasic{})
+ RegisterTestCase(FilterInputJumpReturn{})
+ RegisterTestCase(FilterInputJumpReturnDrop{})
+ RegisterTestCase(FilterInputJumpBuiltin{})
+ RegisterTestCase(FilterInputJumpTwice{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -248,3 +260,338 @@ func (FilterInputDropAll) ContainerAction(ip net.IP) error {
func (FilterInputDropAll) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, dropPort, sendloopDuration)
}
+
+// FilterInputMultiUDPRules verifies that multiple UDP rules are applied
+// correctly. This has the added benefit of testing whether we're serializing
+// rules correctly -- if we do it incorrectly, the iptables tool will
+// misunderstand and save the wrong tables.
+type FilterInputMultiUDPRules struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputMultiUDPRules) Name() string {
+ return "FilterInputMultiUDPRules"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"},
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"},
+ {"-L"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputMultiUDPRules) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be
+// specified.
+type FilterInputRequireProtocolUDP struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputRequireProtocolUDP) Name() string {
+ return "FilterInputRequireProtocolUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil {
+ return errors.New("expected iptables to fail with out \"-p udp\", but succeeded")
+ }
+ return nil
+}
+
+func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputCreateUserChain tests chain creation.
+type FilterInputCreateUserChain struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputCreateUserChain) Name() string {
+ return "FilterInputCreateUserChain"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ // Create a chain.
+ {"-N", chainName},
+ // Add a simple rule to the chain.
+ {"-A", chainName, "-j", "DROP"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputCreateUserChain) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputDefaultPolicyAccept tests the default ACCEPT policy.
+type FilterInputDefaultPolicyAccept struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyAccept) Name() string {
+ return "FilterInputDefaultPolicyAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error {
+ // Set the default policy to accept, then receive a packet.
+ if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil {
+ return err
+ }
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputDefaultPolicyDrop tests the default DROP policy.
+type FilterInputDefaultPolicyDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyDrop) Name() string {
+ return "FilterInputDefaultPolicyDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error {
+ if err := filterTable("-P", "INPUT", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes
+// the underflow rule (i.e. default policy) to be executed.
+type FilterInputReturnUnderflow struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputReturnUnderflow) Name() string {
+ return "FilterInputReturnUnderflow"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error {
+ // Add a RETURN rule followed by an unconditional accept, and set the
+ // default policy to DROP.
+ rules := [][]string{
+ {"-A", "INPUT", "-j", "RETURN"},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-P", "INPUT", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // We should receive packets, as the RETURN rule will trigger the default
+ // ACCEPT policy.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputSerializeJump verifies that we can serialize jumps.
+type FilterInputSerializeJump struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputSerializeJump) Name() string {
+ return "FilterInputSerializeJump"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSerializeJump) ContainerAction(ip net.IP) error {
+ // Write a JUMP rule, the serialize it with `-L`.
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-L"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSerializeJump) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpBasic jumps to a chain and executes a rule there.
+type FilterInputJumpBasic struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBasic) Name() string {
+ return "FilterInputJumpBasic"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBasic) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBasic) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturn jumps, returns, and executes a rule.
+type FilterInputJumpReturn struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturn) Name() string {
+ return "FilterInputJumpReturn"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturn) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-P", "INPUT", "ACCEPT"},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "RETURN"},
+ {"-A", chainName, "-j", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturn) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets.
+type FilterInputJumpReturnDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturnDrop) Name() string {
+ return "FilterInputJumpReturnDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-A", chainName, "-j", "RETURN"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal.
+type FilterInputJumpBuiltin struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBuiltin) Name() string {
+ return "FilterInputJumpBuiltin"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil {
+ return fmt.Errorf("iptables should be unable to jump to a built-in chain")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpTwice jumps twice, then returns twice and executes a rule.
+type FilterInputJumpTwice struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpTwice) Name() string {
+ return "FilterInputJumpTwice"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
+ const chainName2 = chainName + "2"
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-N", chainName2},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", chainName2},
+ {"-A", "INPUT", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // UDP packets should jump and return twice, eventually hitting the
+ // ACCEPT rule.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 679a29bef..0621861eb 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -30,7 +30,7 @@ import (
const timeout = 18 * time.Second
-var image = flag.String("image", "bazel/test/iptables/runner:runner", "image to run tests in")
+var image = flag.String("image", "bazel/test/iptables/runner:runner-image", "image to run tests in")
type result struct {
output string
@@ -214,6 +214,30 @@ func TestFilterInputDropTCPSrcPort(t *testing.T) {
}
}
+func TestFilterInputCreateUserChain(t *testing.T) {
+ if err := singleTest(FilterInputCreateUserChain{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterInputDefaultPolicyAccept(t *testing.T) {
+ if err := singleTest(FilterInputDefaultPolicyAccept{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterInputDefaultPolicyDrop(t *testing.T) {
+ if err := singleTest(FilterInputDefaultPolicyDrop{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestFilterInputReturnUnderflow(t *testing.T) {
+ if err := singleTest(FilterInputReturnUnderflow{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
func TestFilterOutputDropTCPDestPort(t *testing.T) {
if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil {
t.Fatal(err)
@@ -225,3 +249,39 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestJumpSerialize(t *testing.T) {
+ if err := singleTest(FilterInputSerializeJump{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpBasic(t *testing.T) {
+ if err := singleTest(FilterInputJumpBasic{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpReturn(t *testing.T) {
+ if err := singleTest(FilterInputJumpReturn{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpReturnDrop(t *testing.T) {
+ if err := singleTest(FilterInputJumpReturnDrop{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpBuiltin(t *testing.T) {
+ if err := singleTest(FilterInputJumpBuiltin{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpTwice(t *testing.T) {
+ if err := singleTest(FilterInputJumpTwice{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 043114c78..32cf5a417 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -27,7 +27,16 @@ const iptablesBinary = "iptables"
// filterTable calls `iptables -t filter` with the given args.
func filterTable(args ...string) error {
- args = append([]string{"-t", "filter"}, args...)
+ return tableCmd("filter", args)
+}
+
+// natTable calls `iptables -t nat` with the given args.
+func natTable(args ...string) error {
+ return tableCmd("nat", args)
+}
+
+func tableCmd(table string, args []string) error {
+ args = append([]string{"-t", table}, args...)
cmd := exec.Command(iptablesBinary, args...)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
@@ -35,6 +44,16 @@ func filterTable(args ...string) error {
return nil
}
+// filterTableRules is like filterTable, but runs multiple iptables commands.
+func filterTableRules(argsList [][]string) error {
+ for _, args := range argsList {
+ if err := filterTable(args...); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
// the first read on that port.
func listenUDP(port int, timeout time.Duration) error {
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index b5c6f927e..a01117ec8 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -38,7 +38,7 @@ func (NATRedirectUDPPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
return err
}
@@ -63,7 +63,7 @@ func (NATDropUDP) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (NATDropUDP) ContainerAction(ip net.IP) error {
- if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
return err
}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
new file mode 100644
index 000000000..d113555b1
--- /dev/null
+++ b/test/packetdrill/BUILD
@@ -0,0 +1,8 @@
+load("defs.bzl", "packetdrill_test")
+
+package(licenses = ["notice"])
+
+packetdrill_test(
+ name = "fin_wait2_timeout",
+ scripts = ["fin_wait2_timeout.pkt"],
+)
diff --git a/test/packetdrill/Dockerfile b/test/packetdrill/Dockerfile
new file mode 100644
index 000000000..bd4451355
--- /dev/null
+++ b/test/packetdrill/Dockerfile
@@ -0,0 +1,9 @@
+FROM ubuntu:bionic
+
+RUN apt-get update
+RUN apt-get install -y net-tools git iptables iputils-ping netcat tcpdump jq tar
+RUN hash -r
+RUN git clone --branch packetdrill-v2.0 \
+ https://github.com/google/packetdrill.git
+RUN cd packetdrill/gtests/net/packetdrill && ./configure && \
+ apt-get install -y bison flex make && make
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
new file mode 100644
index 000000000..8623ce7b1
--- /dev/null
+++ b/test/packetdrill/defs.bzl
@@ -0,0 +1,87 @@
+"""Defines a rule for packetdrill test targets."""
+
+def _packetdrill_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+
+ script_paths = []
+ for script in ctx.files.scripts:
+ script_paths.append(script.short_path)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ # This test will run part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -exec chmod a+rx {} \\;",
+ "find . -type d -exec chmod a+rx {} \\;",
+ "%s %s --init_script %s $@ -- %s\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files._init_script[0].short_path,
+ " ".join(script_paths),
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ transitive_files = depset()
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files = depset(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files._init_script + ctx.files.scripts,
+ transitive_files = transitive_files,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_packetdrill_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "host",
+ allow_files = True,
+ default = "packetdrill_test.sh",
+ ),
+ "_init_script": attr.label(
+ allow_single_file = True,
+ default = "packetdrill_setup.sh",
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ "scripts": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _packetdrill_test_impl,
+)
+
+_PACKETDRILL_TAGS = ["local", "manual"]
+
+def packetdrill_linux_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = _PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name + "_linux_test",
+ flags = ["--dut_platform", "linux"],
+ **kwargs
+ )
+
+def packetdrill_netstack_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = _PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name + "_netstack_test",
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
+ **kwargs
+ )
+
+def packetdrill_test(**kwargs):
+ packetdrill_linux_test(**kwargs)
+ packetdrill_netstack_test(**kwargs)
diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt
new file mode 100644
index 000000000..613f0bec9
--- /dev/null
+++ b/test/packetdrill/fin_wait2_timeout.pkt
@@ -0,0 +1,23 @@
+// Test that a socket in FIN_WAIT_2 eventually times out and a subsequent
+// packet generates a RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+// set FIN_WAIT2 timeout to 1 seconds.
++0.100 setsockopt(4, SOL_TCP, TCP_LINGER2, [1], 4) = 0
++0 close(4) = 0
+
++0 > F. 1:1(0) ack 1 <...>
++0 < . 1:1(0) ack 2 win 257
+
++1.1 < . 1:1(0) ack 2 win 257
++0 > R 2:2(0) win 0
diff --git a/test/packetdrill/packetdrill_setup.sh b/test/packetdrill/packetdrill_setup.sh
new file mode 100755
index 000000000..b858072f0
--- /dev/null
+++ b/test/packetdrill/packetdrill_setup.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script runs both within the sentry context and natively. It should tweak
+# TCP parameters to match expectations found in the script files.
+sysctl -q net.ipv4.tcp_sack=1
+sysctl -q net.ipv4.tcp_rmem="4096 2097152 $((8*1024*1024))"
+sysctl -q net.ipv4.tcp_wmem="4096 2097152 $((8*1024*1024))"
+
+# There may be errors from the above, but they will show up in the test logs and
+# we always want to proceed from this point. It's possible that values were
+# already set correctly and the nodes were not available in the namespace.
+exit 0
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
new file mode 100755
index 000000000..0b22dfd5c
--- /dev/null
+++ b/test/packetdrill/packetdrill_test.sh
@@ -0,0 +1,225 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run a packetdrill test. Two docker containers are made, one for the
+# Device-Under-Test (DUT) and one for the test runner. Each is attached with
+# two networks, one for control packets that aid the test and one for test
+# packets which are sent as part of the test and observed for correctness.
+
+set -euxo pipefail
+
+function failure() {
+ local lineno=$1
+ local msg=$2
+ local filename="$0"
+ echo "FAIL: $filename:$lineno: $msg"
+}
+trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
+
+declare -r LONGOPTS="dut_platform:,init_script:,runtime:"
+
+# Don't use declare below so that the error from getopt will end the script.
+PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
+
+eval set -- "$PARSED"
+
+while true; do
+ case "$1" in
+ --dut_platform)
+ # Either "linux" or "netstack".
+ declare -r DUT_PLATFORM="$2"
+ shift 2
+ ;;
+ --init_script)
+ declare -r INIT_SCRIPT="$2"
+ shift 2
+ ;;
+ --runtime)
+ # Not readonly because there might be multiple --runtime arguments and we
+ # want to use just the last one. Only used if --dut_platform is
+ # "netstack".
+ declare RUNTIME="$2"
+ shift 2
+ ;;
+ --)
+ shift
+ break
+ ;;
+ *)
+ echo "Programming error"
+ exit 3
+ esac
+done
+
+# All the other arguments are scripts.
+declare -r scripts="$@"
+
+# Check that the required flags are defined in a way that is safe for "set -u".
+if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
+ if [[ -z "${RUNTIME-}" ]]; then
+ echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
+ exit 2
+ fi
+ declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
+elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
+ declare -r RUNTIME_ARG=""
+else
+ echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
+ exit 2
+fi
+if [[ ! -x "${INIT_SCRIPT-}" ]]; then
+ echo "FAIL: Bad or missing --init_script: ${INIT_SCRIPT-}"
+ exit 2
+fi
+
+# Variables specific to the control network and interface start with CTRL_.
+# Variables specific to the test network and interface start with TEST_.
+# Variables specific to the DUT start with DUT_.
+# Variables specific to the test runner start with TEST_RUNNER_.
+declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill"
+# Use random numbers so that test networks don't collide.
+declare -r CTRL_NET="ctrl_net-${RANDOM}${RANDOM}"
+declare -r TEST_NET="test_net-${RANDOM}${RANDOM}"
+declare -r tolerance_usecs=100000
+# On both DUT and test runner, testing packets are on the eth2 interface.
+declare -r TEST_DEVICE="eth2"
+# Number of bits in the *_NET_PREFIX variables.
+declare -r NET_MASK="24"
+function new_net_prefix() {
+ # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
+}
+# Last bits of the DUT's IP address.
+declare -r DUT_NET_SUFFIX=".10"
+# Control port.
+declare -r CTRL_PORT="40000"
+# Last bits of the test runner's IP address.
+declare -r TEST_RUNNER_NET_SUFFIX=".20"
+declare -r TIMEOUT="60"
+declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetdrill"
+
+# Make sure that docker is installed.
+docker --version
+
+function finish {
+ local cleanup_success=1
+ for net in "${CTRL_NET}" "${TEST_NET}"; do
+ # Kill all processes attached to ${net}.
+ for docker_command in "kill" "rm"; do
+ (docker network inspect "${net}" \
+ --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
+ | xargs -r docker "${docker_command}") || \
+ cleanup_success=0
+ done
+ # Remove the network.
+ docker network rm "${net}" || \
+ cleanup_success=0
+ done
+
+ if ((!$cleanup_success)); then
+ echo "FAIL: Cleanup command failed"
+ exit 4
+ fi
+}
+trap finish EXIT
+
+# Subnet for control packets between test runner and DUT.
+declare CTRL_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
+ sleep 0.1
+ declare CTRL_NET_PREFIX=$(new_net_prefix)
+done
+
+# Subnet for the packets that are part of the test.
+declare TEST_NET_PREFIX=$(new_net_prefix)
+while ! docker network create \
+ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
+ sleep 0.1
+ declare TEST_NET_PREFIX=$(new_net_prefix)
+done
+
+docker pull "${IMAGE_TAG}"
+
+# Create the DUT container and connect to network.
+DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker start "${DUT}"
+
+# Create the test runner container and connect to network.
+TEST_RUNNER=$(docker create --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+ || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+ || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false)
+docker start "${TEST_RUNNER}"
+
+# Run tcpdump in the test runner unbuffered, without dns resolution, just on the
+# interface with the test packets.
+docker exec -t ${TEST_RUNNER} tcpdump -U -n -i "${TEST_DEVICE}" &
+
+# Start a packetdrill server on the test_runner. The packetdrill server sends
+# packets and asserts that they are received.
+docker exec -d "${TEST_RUNNER}" \
+ ${PACKETDRILL} --wire_server --wire_server_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}"
+
+# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will
+# issue a RST. To prevent this IPtables can be used to filter those out.
+docker exec "${TEST_RUNNER}" \
+ iptables -A OUTPUT -p tcp --tcp-flags RST RST -j DROP
+
+# Wait for the packetdrill server on the test runner to come. Attempt to
+# connect to it from the DUT every 100 milliseconds until success.
+while ! docker exec "${DUT}" \
+ nc -zv "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${CTRL_PORT}"; do
+ sleep 0.1
+done
+
+# Copy the packetdrill setup script to the DUT.
+docker cp -L "${INIT_SCRIPT}" "${DUT}:packetdrill_setup.sh"
+
+# Copy the packetdrill scripts to the DUT.
+declare -a dut_scripts
+for script in $scripts; do
+ docker cp -L "${script}" "${DUT}:$(basename ${script})"
+ dut_scripts+=("/$(basename ${script})")
+done
+
+# Start a packetdrill client on the DUT. The packetdrill client runs POSIX
+# socket commands and also sends instructions to the server.
+docker exec -t "${DUT}" \
+ ${PACKETDRILL} --wire_client --wire_client_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --init_scripts=/packetdrill_setup.sh \
+ --tolerance_usecs="${tolerance_usecs}" "${dut_scripts[@]}"
+
+echo PASS: No errors.
diff --git a/test/perf/BUILD b/test/perf/BUILD
new file mode 100644
index 000000000..346a28e16
--- /dev/null
+++ b/test/perf/BUILD
@@ -0,0 +1,114 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ test = "//test/perf/linux:clock_getres_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:clock_gettime_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:death_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:epoll_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:fork_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:futex_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:getdents_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:getpid_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:gettid_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:mapping_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:open_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:pipe_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:randread_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:read_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:sched_yield_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:send_recv_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:seqwrite_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:signal_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:sleep_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:stat_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ add_overlay = True,
+ test = "//test/perf/linux:unlink_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:write_benchmark",
+)
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
new file mode 100644
index 000000000..b4e907826
--- /dev/null
+++ b/test/perf/linux/BUILD
@@ -0,0 +1,356 @@
+load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "getpid_benchmark",
+ testonly = 1,
+ srcs = [
+ "getpid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "send_recv_benchmark",
+ testonly = 1,
+ srcs = [
+ "send_recv_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/syscalls/linux:socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "gettid_benchmark",
+ testonly = 1,
+ srcs = [
+ "gettid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "sched_yield_benchmark",
+ testonly = 1,
+ srcs = [
+ "sched_yield_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_getres_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_getres_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "clock_gettime_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_gettime_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "open_benchmark",
+ testonly = 1,
+ srcs = [
+ "open_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "read_benchmark",
+ testonly = 1,
+ srcs = [
+ "read_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "randread_benchmark",
+ testonly = 1,
+ srcs = [
+ "randread_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "write_benchmark",
+ testonly = 1,
+ srcs = [
+ "write_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "seqwrite_benchmark",
+ testonly = 1,
+ srcs = [
+ "seqwrite_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "pipe_benchmark",
+ testonly = 1,
+ srcs = [
+ "pipe_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fork_benchmark",
+ testonly = 1,
+ srcs = [
+ "fork_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "futex_benchmark",
+ testonly = 1,
+ srcs = [
+ "futex_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "epoll_benchmark",
+ testonly = 1,
+ srcs = [
+ "epoll_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "death_benchmark",
+ testonly = 1,
+ srcs = [
+ "death_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "mapping_benchmark",
+ testonly = 1,
+ srcs = [
+ "mapping_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "signal_benchmark",
+ testonly = 1,
+ srcs = [
+ "signal_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getdents_benchmark",
+ testonly = 1,
+ srcs = [
+ "getdents_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sleep_benchmark",
+ testonly = 1,
+ srcs = [
+ "sleep_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "stat_benchmark",
+ testonly = 1,
+ srcs = [
+ "stat_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "unlink_benchmark",
+ testonly = 1,
+ srcs = [
+ "unlink_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc
new file mode 100644
index 000000000..b051293ad
--- /dev/null
+++ b/test/perf/linux/clock_getres_benchmark.cc
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// clock_getres(1) is very nearly a no-op syscall, but it does require copying
+// out to a userspace struct. It thus provides a nice small copy-out benchmark.
+void BM_ClockGetRes(benchmark::State& state) {
+ struct timespec ts;
+ for (auto _ : state) {
+ clock_getres(CLOCK_MONOTONIC, &ts);
+ }
+}
+
+BENCHMARK(BM_ClockGetRes);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc
new file mode 100644
index 000000000..6691bebd9
--- /dev/null
+++ b/test/perf/linux/clock_gettime_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pthread.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_ClockGettimeThreadCPUTime(benchmark::State& state) {
+ clockid_t clockid;
+ ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid));
+ struct timespec tp;
+
+ for (auto _ : state) {
+ clock_gettime(clockid, &tp);
+ }
+}
+
+BENCHMARK(BM_ClockGettimeThreadCPUTime);
+
+void BM_VDSOClockGettime(benchmark::State& state) {
+ const clockid_t clock = state.range(0);
+ struct timespec tp;
+ absl::Time start = absl::Now();
+
+ // Don't benchmark the calibration phase.
+ while (absl::Now() < start + absl::Milliseconds(2100)) {
+ clock_gettime(clock, &tp);
+ }
+
+ for (auto _ : state) {
+ clock_gettime(clock, &tp);
+ }
+}
+
+BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc
new file mode 100644
index 000000000..cb2b6fd07
--- /dev/null
+++ b/test/perf/linux/death_benchmark.cc
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing
+// the ability of gVisor (on whatever platform) to execute all the related
+// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH.
+TEST(DeathTest, ZeroEqualsOne) {
+ EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc
new file mode 100644
index 000000000..0b121338a
--- /dev/null
+++ b/test/perf/linux/epoll_benchmark.cc
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Returns a new eventfd.
+PosixErrorOr<FileDescriptor> NewEventFD() {
+ int fd = eventfd(0, /* flags = */ 0);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "eventfd");
+ }
+ return FileDescriptor(fd);
+}
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollTimeout(benchmark::State& state) {
+ constexpr int kFDsPerEpoll = 3;
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+ }
+
+ struct epoll_event result[kFDsPerEpoll];
+ int timeout_ms = state.range(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms));
+ }
+}
+
+BENCHMARK(BM_EpollTimeout)->Range(0, 8);
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollAllEvents(benchmark::State& state) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ const int fds_per_epoll = state.range(0);
+ constexpr uint64_t kEventVal = 5;
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < fds_per_epoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+
+ ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)),
+ SyscallSucceedsWithValue(sizeof(kEventVal)));
+ }
+
+ std::vector<struct epoll_event> result(fds_per_epoll);
+
+ for (auto _ : state) {
+ EXPECT_EQ(fds_per_epoll,
+ epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0));
+ }
+}
+
+BENCHMARK(BM_EpollAllEvents)->Range(2, 1024);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc
new file mode 100644
index 000000000..84fdbc8a0
--- /dev/null
+++ b/test/perf/linux/fork_benchmark.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/barrier.h"
+#include "benchmark/benchmark.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBusyMax = 250;
+
+// Do some CPU-bound busy-work.
+int busy(int max) {
+ // Prevent the compiler from optimizing this work away,
+ volatile int count = 0;
+
+ for (int i = 1; i < max; i++) {
+ for (int j = 2; j < i / 2; j++) {
+ if (i % j == 0) {
+ count++;
+ }
+ }
+ }
+
+ return count;
+}
+
+void BM_CPUBoundUniprocess(benchmark::State& state) {
+ for (auto _ : state) {
+ busy(kBusyMax);
+ }
+}
+
+BENCHMARK(BM_CPUBoundUniprocess);
+
+void BM_CPUBoundAsymmetric(benchmark::State& state) {
+ const size_t max = state.max_iterations;
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < max; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ ASSERT_TRUE(state.KeepRunningBatch(max));
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ ASSERT_FALSE(state.KeepRunning());
+}
+
+BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime();
+
+void BM_CPUBoundSymmetric(benchmark::State& state) {
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ ASSERT_FALSE(state.KeepRunning());
+ });
+
+ const int processes = state.range(0);
+ for (int i = 0; i < processes; i++) {
+ size_t cur = (state.max_iterations + (processes - 1)) / processes;
+ if ((state.iterations() + cur) >= state.max_iterations) {
+ cur = state.max_iterations - state.iterations();
+ }
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < cur; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ if (cur > 0) {
+ // We can have a zero cur here, depending.
+ ASSERT_TRUE(state.KeepRunningBatch(cur));
+ }
+ children.push_back(child);
+ }
+}
+
+BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime();
+
+// Child routine for ProcessSwitch/ThreadSwitch.
+// Reads from readfd and writes the result to writefd.
+void SwitchChild(int readfd, int writefd) {
+ while (1) {
+ char buf;
+ int ret = ReadFd(readfd, &buf, 1);
+ if (ret == 0) {
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "read failed");
+
+ ret = WriteFd(writefd, &buf, 1);
+ if (ret == -1) {
+ TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure");
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "write failed");
+ }
+}
+
+// Send bytes in a loop through a series of pipes, each passing through a
+// different process.
+//
+// Proc 0 Proc 1
+// * ----------> *
+// ^ Pipe 1 |
+// | |
+// | Pipe 0 | Pipe 2
+// | |
+// | |
+// | Pipe 3 v
+// * <---------- *
+// Proc 3 Proc 2
+//
+// This exercises context switching through multiple processes.
+void BM_ProcessSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two processes.
+ const int num_processes = state.range(0);
+ ASSERT_GE(num_processes, 2);
+
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ });
+
+ // Must come after children, as the FDs must be closed before the children
+ // will exit.
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_processes; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This process is one of the processes in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_processes; i++) {
+ // Read from current pipe index, write to next.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = (i + 1) % num_processes;
+ const int write_fd = write_fds[write_index].get();
+
+ // std::vector isn't safe to use from the fork child.
+ FileDescriptor* read_array = read_fds.data();
+ FileDescriptor* write_array = write_fds.data();
+
+ pid_t child = fork();
+ if (!child) {
+ // Close all other FDs.
+ for (int j = 0; j < num_processes; j++) {
+ if (j != read_index) {
+ read_array[j].reset();
+ }
+ if (j != write_index) {
+ write_array[j].reset();
+ }
+ }
+
+ SwitchChild(read_fd, write_fd);
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ children.push_back(child);
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+}
+
+BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime();
+
+// Equivalent to BM_ThreadSwitch using threads instead of processes.
+void BM_ThreadSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two threads.
+ const int num_threads = state.range(0);
+ ASSERT_GE(num_threads, 2);
+
+ // Must come after threads, as the FDs must be closed before the children
+ // will exit.
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_threads; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This thread is one of the threads in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_threads; i++) {
+ // Read from current pipe index, write to next.
+ //
+ // Transfer ownership of the FDs to the thread.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].release();
+
+ const int write_index = (i + 1) % num_threads;
+ const int write_fd = write_fds[write_index].release();
+
+ threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] {
+ FileDescriptor read(read_fd);
+ FileDescriptor write(write_fd);
+ SwitchChild(read.get(), write.get());
+ }));
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+
+ // The two FDs still owned by this thread are closed, causing the next thread
+ // to exit its loop and close its FDs, and so on until all threads exit.
+}
+
+BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime();
+
+void BM_ThreadStart(benchmark::State& state) {
+ const int num_threads = state.range(0);
+
+ for (auto _ : state) {
+ state.PauseTiming();
+
+ auto barrier = new absl::Barrier(num_threads + 1);
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+
+ state.ResumeTiming();
+
+ for (size_t i = 0; i < num_threads; ++i) {
+ threads.emplace_back(std::make_unique<ScopedThread>([barrier] {
+ if (barrier->Block()) {
+ delete barrier;
+ }
+ }));
+ }
+
+ if (barrier->Block()) {
+ delete barrier;
+ }
+
+ state.PauseTiming();
+
+ for (const auto& thread : threads) {
+ thread->Join();
+ }
+
+ state.ResumeTiming();
+ }
+}
+
+BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime();
+
+// Benchmark the complete fork + exit + wait.
+void BM_ProcessLifecycle(benchmark::State& state) {
+ const int num_procs = state.range(0);
+
+ std::vector<pid_t> pids(num_procs);
+ for (auto _ : state) {
+ for (size_t i = 0; i < num_procs; ++i) {
+ int pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ pids[i] = pid;
+ }
+
+ for (const int pid : pids) {
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0),
+ SyscallSucceedsWithValue(pid));
+ }
+ }
+}
+
+BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
new file mode 100644
index 000000000..b349d50bf
--- /dev/null
+++ b/test/perf/linux/futex_benchmark.cc
@@ -0,0 +1,248 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/futex.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
+ return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr);
+}
+
+inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* reltime) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime);
+}
+
+inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* abstime) {
+ return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime);
+}
+
+inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+ int32_t bits,
+ const struct timespec* abstime) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
+ val, abstime, nullptr, bits);
+}
+
+inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
+ return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count);
+}
+
+// This just uses FUTEX_WAKE on an address with nothing waiting, very simple.
+void BM_FutexWakeNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, FutexWake(&v, 1));
+ }
+}
+
+BENCHMARK(BM_FutexWakeNop);
+
+// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
+// syscall won't wait.
+void BM_FutexWaitNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(-EAGAIN, FutexWait(&v, 1));
+ }
+}
+
+BENCHMARK(BM_FutexWaitNop);
+
+// This uses FUTEX_WAIT with a timeout on an address whose value never
+// changes, such that it always times out. Timeout overhead can be estimated by
+// timer overruns for short timeouts.
+void BM_FutexWaitTimeout(benchmark::State& state) {
+ const int timeout_ns = state.range(0);
+ std::atomic<int32_t> v(0);
+ auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+
+ for (auto _ : state) {
+ EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts));
+ }
+}
+
+BENCHMARK(BM_FutexWaitTimeout)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(100)
+ ->Arg(1000)
+ ->Arg(10000);
+
+// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME.
+void BM_FutexWaitBitset(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ int timeout_ns = state.range(0);
+ auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+ for (auto _ : state) {
+ EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts));
+ }
+}
+
+BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000);
+
+int64_t GetCurrentMonotonicTimeNanos() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1);
+ return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+void SpinNanos(int64_t delay_ns) {
+ if (delay_ns <= 0) {
+ return;
+ }
+ const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns;
+ while (GetCurrentMonotonicTimeNanos() < end) {
+ // spin
+ }
+}
+
+// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
+// wakeup to another thread, which spins for delay_us and then sends a futex
+// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs +
+// futex/scheduling overhead).
+void BM_FutexRoundtripDelayed(benchmark::State& state) {
+ const int delay_us = state.range(0);
+
+ const int64_t delay_ns = delay_us * 1000;
+ // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
+ // the probability that the wakeup comes before the wait, preventing the wait
+ // from ever taking effect and causing the benchmark to underestimate the
+ // actual wakeup time.
+ constexpr int64_t kBeforeWakeDelayNs = 500;
+ std::atomic<int32_t> v(0);
+ ScopedThread t([&] {
+ for (int i = 0; i < state.max_iterations; i++) {
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 0) {
+ FutexWait(&v, 0);
+ }
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(0, std::memory_order_release);
+ FutexWake(&v, 1);
+ }
+ });
+ for (auto _ : state) {
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(1, std::memory_order_release);
+ FutexWake(&v, 1);
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 1) {
+ FutexWait(&v, 1);
+ }
+ }
+}
+
+BENCHMARK(BM_FutexRoundtripDelayed)
+ ->Arg(0)
+ ->Arg(10)
+ ->Arg(20)
+ ->Arg(50)
+ ->Arg(100);
+
+// FutexLock is a simple, dumb futex based lock implementation.
+// It will try to acquire the lock by atomically incrementing the
+// lock word. If it did not increment the lock from 0 to 1, someone
+// else has the lock, so it will FUTEX_WAIT until it is woken in
+// the unlock path.
+class FutexLock {
+ public:
+ FutexLock() : lock_word_(0) {}
+
+ void lock(struct timespec* deadline) {
+ int32_t val;
+ while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) !=
+ 1) {
+ // If we didn't get the lock by incrementing from 0 to 1,
+ // do a FUTEX_WAIT with the desired current value set to
+ // val. If val is no longer what the atomic increment returned,
+ // someone might have set it to 0 so we can try to acquire
+ // again.
+ int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline);
+ if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) {
+ continue;
+ } else {
+ FAIL() << "unexpected FUTEX_WAIT return: " << ret;
+ }
+ }
+ }
+
+ void unlock() {
+ // Store 0 into the lock word and wake one waiter. We intentionally
+ // ignore the return value of the FUTEX_WAKE here, since there may be
+ // no waiters to wake anyway.
+ lock_word_.store(0, std::memory_order_release);
+ (void)FutexWake(&lock_word_, 1);
+ }
+
+ private:
+ std::atomic<int32_t> lock_word_;
+};
+
+FutexLock* test_lock; // Used below.
+
+void FutexContend(benchmark::State& state, int thread_index,
+ struct timespec* deadline) {
+ int counter = 0;
+ if (thread_index == 0) {
+ test_lock = new FutexLock();
+ }
+ for (auto _ : state) {
+ test_lock->lock(deadline);
+ counter++;
+ test_lock->unlock();
+ }
+ if (thread_index == 0) {
+ delete test_lock;
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+void BM_FutexContend(benchmark::State& state) {
+ FutexContend(state, state.thread_index, nullptr);
+}
+
+BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime();
+
+void BM_FutexDeadlineContend(benchmark::State& state) {
+ auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10));
+ FutexContend(state, state.thread_index, &deadline);
+}
+
+BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
new file mode 100644
index 000000000..afc599ad2
--- /dev/null
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+#ifndef SYS_getdents64
+#if defined(__x86_64__)
+#define SYS_getdents64 217
+#elif defined(__aarch64__)
+#define SYS_getdents64 217
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_getdents64
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBufferSize = 16384;
+
+PosixErrorOr<TempPath> CreateDirectory(int count,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir());
+
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (int i = 0; i < count; i++) {
+ auto file = NewTempRelPath();
+ auto res = MknodAt(dfd, file, S_IFREG | 0644, 0);
+ RETURN_IF_ERRNO(res);
+ files->push_back(file);
+ }
+
+ return std::move(dir);
+}
+
+PosixError CleanupDirectory(const TempPath& dir,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (auto it = files->begin(); it != files->end(); ++it) {
+ auto res = UnlinkAt(dfd, *it, 0);
+ RETURN_IF_ERRNO(res);
+ }
+ return NoError();
+}
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a single FD.
+void BM_GetdentsSameFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds());
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime();
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a new FD each time.
+void BM_GetdentsNewFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
new file mode 100644
index 000000000..db74cb264
--- /dev/null
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Getpid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_getpid);
+ }
+}
+
+BENCHMARK(BM_Getpid);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc
new file mode 100644
index 000000000..8f4961f5e
--- /dev/null
+++ b/test/perf/linux/gettid_benchmark.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Gettid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_gettid);
+ }
+}
+
+BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc
new file mode 100644
index 000000000..39c30fe69
--- /dev/null
+++ b/test/perf/linux/mapping_benchmark.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Conservative value for /proc/sys/vm/max_map_count, which limits the number of
+// VMAs, minus a safety margin for VMAs that already exist for the test binary.
+// The default value for max_map_count is
+// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530.
+constexpr size_t kMaxVMAs = 64001;
+
+// Map then unmap pages without touching them.
+void BM_MapUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map, touch, then unmap pages.
+void BM_MapTouchUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ char* end = c + pages * kPageSize;
+ while (c < end) {
+ *c = 42;
+ c += kPageSize;
+ }
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map and touch many pages, unmapping all at once.
+//
+// NOTE(b/111429208): This is a regression test to ensure performant mapping and
+// allocation even with tons of mappings.
+void BM_MapTouchMany(benchmark::State& state) {
+ // Number of pages to map.
+ const int page_count = state.range(0);
+
+ while (state.KeepRunning()) {
+ std::vector<void*> pages;
+
+ for (int i = 0; i < page_count; i++) {
+ void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ *c = 42;
+
+ pages.push_back(addr);
+ }
+
+ for (void* addr : pages) {
+ int ret = munmap(addr, kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+ }
+
+ state.SetBytesProcessed(kPageSize * page_count * state.iterations());
+}
+
+BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime();
+
+void BM_PageFault(benchmark::State& state) {
+ // Map the region in which we will take page faults. To ensure that each page
+ // fault maps only a single page, each page we touch must correspond to a
+ // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However,
+ // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that
+ // if there are background threads running, they can't inadvertently creating
+ // mappings in our gaps that are unmapped when the test ends.
+ size_t test_pages = kMaxVMAs;
+ // Ensure that test_pages is odd, since we want the test region to both
+ // begin and end with a mapped page.
+ if (test_pages % 2 == 0) {
+ test_pages--;
+ }
+ const size_t test_region_bytes = test_pages * kPageSize;
+ // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on
+ // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE
+ // so that Linux pre-allocates the shmem file used to back the mapping.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE));
+ for (size_t i = 0; i < test_pages / 2; i++) {
+ ASSERT_THAT(
+ mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)),
+ kPageSize, PROT_NONE),
+ SyscallSucceeds());
+ }
+
+ const size_t mapped_pages = test_pages / 2 + 1;
+ // "Start" at the end of the mapped region to force the mapped region to be
+ // reset, since we mapped it with MAP_POPULATE.
+ size_t cur_page = mapped_pages;
+ for (auto _ : state) {
+ if (cur_page >= mapped_pages) {
+ // We've reached the end of our mapped region and have to reset it to
+ // incur page faults again.
+ state.PauseTiming();
+ ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED),
+ SyscallSucceeds());
+ cur_page = 0;
+ state.ResumeTiming();
+ }
+ const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize);
+ const char c = *reinterpret_cast<volatile char*>(addr);
+ benchmark::DoNotOptimize(c);
+ cur_page++;
+ }
+}
+
+BENCHMARK(BM_PageFault)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc
new file mode 100644
index 000000000..68008f6d5
--- /dev/null
+++ b/test/perf/linux/open_benchmark.cc
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Open(benchmark::State& state) {
+ const int size = state.range(0);
+ std::vector<TempPath> cache;
+ for (int i = 0; i < size; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ cache.emplace_back(std::move(path));
+ }
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(cache[chosen].path().c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ close(fd);
+ }
+}
+
+BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc
new file mode 100644
index 000000000..8f5f6a2a3
--- /dev/null
+++ b/test/perf/linux/pipe_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <cerrno>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Pipe(benchmark::State& state) {
+ int fds[2];
+ TEST_CHECK(pipe(fds) == 0);
+
+ const int size = state.range(0);
+ std::vector<char> wbuf(size);
+ std::vector<char> rbuf(size);
+ RandomizeBuffer(wbuf.data(), size);
+
+ ScopedThread t([&] {
+ auto const fd = fds[1];
+ for (int i = 0; i < state.max_iterations; i++) {
+ TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size);
+ }
+ });
+
+ for (auto _ : state) {
+ TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size);
+ }
+
+ t.Join();
+
+ close(fds[0]);
+ close(fds[1]);
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc
new file mode 100644
index 000000000..b0eb8c24e
--- /dev/null
+++ b/test/perf/linux/randread_benchmark.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Create a 1GB file that will be read from at random positions. This should
+// invalid any performance gains from caching.
+const uint64_t kFileSize = 1ULL << 30;
+
+// How many bytes to write at once to initialize the file used to read from.
+const uint32_t kWriteSize = 65536;
+
+// Largest benchmarked read unit.
+const uint32_t kMaxRead = 1UL << 26;
+
+TempPath CreateFile(uint64_t file_size) {
+ auto path = TempPath::CreateFile().ValueOrDie();
+ FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie();
+
+ // Try to minimize syscalls by using maximum size writev() requests.
+ std::vector<char> buffer(kWriteSize);
+ RandomizeBuffer(buffer.data(), buffer.size());
+ const std::vector<std::vector<struct iovec>> iovecs_list =
+ GenerateIovecs(file_size, buffer.data(), buffer.size());
+ for (const auto& iovecs : iovecs_list) {
+ TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0);
+ }
+
+ return path;
+}
+
+// Global test state, initialized once per process lifetime.
+struct GlobalState {
+ const TempPath tmpfile;
+ explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {}
+};
+
+GlobalState& GetGlobalState() {
+ // This gets created only once throughout the lifetime of the process.
+ // Use a dynamically allocated object (that is never deleted) to avoid order
+ // of destruction of static storage variables issues.
+ static GlobalState* const state =
+ // The actual file size is the maximum random seek range (kFileSize) + the
+ // maximum read size so we can read that number of bytes at the end of the
+ // file.
+ new GlobalState(CreateFile(kFileSize + kMaxRead));
+ return *state;
+}
+
+void BM_RandRead(benchmark::State& state) {
+ const int size = state.range(0);
+
+ GlobalState& global_state = GetGlobalState();
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY));
+ std::vector<char> buf(size);
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(),
+ rand_r(&seed) % kFileSize) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc
new file mode 100644
index 000000000..62445867d
--- /dev/null
+++ b/test/perf/linux/read_benchmark.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Read(benchmark::State& state) {
+ const int size = state.range(0);
+ const std::string contents(size, 0);
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY));
+
+ std::vector<char> buf(size);
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc
new file mode 100644
index 000000000..6756b5575
--- /dev/null
+++ b/test/perf/linux/sched_yield_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Sched_yield(benchmark::State& state) {
+ for (auto ignored : state) {
+ TEST_CHECK(sched_yield() == 0);
+ }
+}
+
+BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc
new file mode 100644
index 000000000..d73e49523
--- /dev/null
+++ b/test/perf/linux/send_recv_benchmark.cc
@@ -0,0 +1,372 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "benchmark/benchmark.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr ssize_t kMessageSize = 1024;
+
+class Message {
+ public:
+ explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {}
+
+ explicit Message(int byte, int sz) : Message(byte, sz, 0) {}
+
+ explicit Message(int byte, int sz, int cmsg_sz)
+ : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) {
+ iov_.iov_base = buffer_.data();
+ iov_.iov_len = sz;
+ hdr_.msg_iov = &iov_;
+ hdr_.msg_iovlen = 1;
+ hdr_.msg_control = cmsg_buffer_.data();
+ hdr_.msg_controllen = cmsg_sz;
+ }
+
+ struct msghdr* header() {
+ return &hdr_;
+ }
+
+ private:
+ std::vector<char> buffer_;
+ std::vector<char> cmsg_buffer_;
+ struct iovec iov_ = {};
+ struct msghdr hdr_ = {};
+};
+
+void BM_Recvmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvmsg)->UseRealTime();
+
+void BM_Sendmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendmsg(send_socket.get(), send_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendmsg)->UseRealTime();
+
+void BM_Recvfrom(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&send_socket, &send_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ }
+ });
+
+ int bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvfrom)->UseRealTime();
+
+void BM_Sendto(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&recv_socket, &recv_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendto)->UseRealTime();
+
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate
+// space for control messages. Note that we do not expect to receive any.
+void BM_RecvmsgWithControlBuf(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ absl::Notification notification;
+ Message send_msg('a');
+ // Create a msghdr with a buffer allocated for control messages.
+ Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24);
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime();
+
+// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes.
+//
+// state.Args[0] indicates whether the underlying socket should be blocking or
+// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking.
+// state.Args[1] is the size of the payload to be used per sendmsg call.
+void BM_SendmsgTCP(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // Check if we want to run the test w/ a blocking send socket
+ // or non-blocking.
+ const int blocking = state.range(0);
+ if (!blocking) {
+ // Set the send FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds());
+ }
+
+ absl::Notification notification;
+
+ // Get the buffer size we should use for this iteration of the test.
+ const int buf_size = state.range(1);
+ Message send_msg('a', buf_size), recv_msg(0, buf_size);
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ int ncalls = 0;
+ for (auto ignored : state) {
+ int sent = 0;
+ while (true) {
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ struct msghdr* snd_header = send_msg.header();
+ iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent;
+ iov.iov_len = snd_header->msg_iov->iov_len - sent;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+ int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0);
+ ncalls++;
+ if (n > 0) {
+ sent += n;
+ if (sent == buf_size) {
+ break;
+ }
+ // n can be > 0 but less than requested size. In which case we don't
+ // poll.
+ continue;
+ }
+ // Poll the fd for it to become writable.
+ struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10),
+ SyscallSucceedsWithValue(0));
+ }
+ bytes_sent += static_cast<int64_t>(sent);
+ }
+
+ notification.Notify();
+ send_socket.reset();
+ state.SetBytesProcessed(bytes_sent);
+}
+
+void Args(benchmark::internal::Benchmark* benchmark) {
+ for (int blocking = 0; blocking < 2; blocking++) {
+ for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) {
+ benchmark->Args({blocking, buf_size});
+ }
+ }
+}
+
+BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc
new file mode 100644
index 000000000..af49e4477
--- /dev/null
+++ b/test/perf/linux/seqwrite_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The maximum file size of the test file, when writes get beyond this point
+// they wrap around. This should be large enough to blow away caches.
+const uint64_t kMaxFile = 1 << 30;
+
+// Perform writes of various sizes sequentially to one file. Wraps around if it
+// goes above a certain maximum file size.
+void BM_SeqWrite(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ // Start writes at offset 0.
+ uint64_t offset = 0;
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) ==
+ buf.size());
+ offset += buf.size();
+ // Wrap around if going above the maximum file size.
+ if (offset >= kMaxFile) {
+ offset = 0;
+ }
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
new file mode 100644
index 000000000..a6928df58
--- /dev/null
+++ b/test/perf/linux/signal_benchmark.cc
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <string.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void FixupHandler(int sig, siginfo_t* si, void* void_ctx) {
+ static unsigned int dataval = 0;
+
+ // Skip the offending instruction.
+ ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx);
+ ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval);
+}
+
+void BM_FaultSignalFixup(benchmark::State& state) {
+ // Set up the signal handler.
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_sigaction = FixupHandler;
+ sa.sa_flags = SA_SIGINFO;
+ TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0);
+
+ // Fault, fault, fault.
+ for (auto _ : state) {
+ register volatile unsigned int* ptr asm("rax");
+
+ // Trigger the segfault.
+ ptr = nullptr;
+ *ptr = 0;
+ }
+}
+
+BENCHMARK(BM_FaultSignalFixup)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc
new file mode 100644
index 000000000..99ef05117
--- /dev/null
+++ b/test/perf/linux/sleep_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Sleep for 'param' nanoseconds.
+void BM_Sleep(benchmark::State& state) {
+ const int nanoseconds = state.range(0);
+
+ for (auto _ : state) {
+ struct timespec ts;
+ ts.tv_sec = 0;
+ ts.tv_nsec = nanoseconds;
+
+ int ret;
+ do {
+ ret = syscall(SYS_nanosleep, &ts, &ts);
+ if (ret < 0) {
+ TEST_CHECK(errno == EINTR);
+ }
+ } while (ret < 0);
+ }
+}
+
+BENCHMARK(BM_Sleep)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(1000) // 1us
+ ->Arg(1000 * 1000) // 1ms
+ ->Arg(10 * 1000 * 1000) // 10ms
+ ->Arg(50 * 1000 * 1000) // 50ms
+ ->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc
new file mode 100644
index 000000000..f15424482
--- /dev/null
+++ b/test/perf/linux/stat_benchmark.cc
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a file in a nested directory hierarchy at least `depth` directories
+// deep, and stats that file multiple times.
+void BM_Stat(benchmark::State& state) {
+ // Create nested directories with given depth.
+ int depth = state.range(0);
+ const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string dir_path = top_dir.path();
+
+ while (depth-- > 0) {
+ // Don't use TempPath because it will make paths too long to use.
+ //
+ // The top_dir destructor will clean up this whole tree.
+ dir_path = JoinPath(dir_path, absl::StrCat(depth));
+ ASSERT_NO_ERRNO(Mkdir(dir_path, 0755));
+ }
+
+ // Create the file that will be stat'd.
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path));
+
+ struct stat st;
+ for (auto _ : state) {
+ ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds());
+ }
+}
+
+BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc
new file mode 100644
index 000000000..92243a042
--- /dev/null
+++ b/test/perf/linux/unlink_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a directory containing `files` files, and unlinks all the files.
+void BM_Unlink(benchmark::State& state) {
+ // Create directory with given files.
+ const int file_count = state.range(0);
+
+ // We unlink all files on each iteration, but report this as a "batch"
+ // iteration so that reported times are per file.
+ TempPath dir;
+ while (state.KeepRunningBatch(file_count)) {
+ state.PauseTiming();
+ // N.B. dir is declared outside the loop so that destruction of the previous
+ // iteration's directory occurs here, inside of PauseTiming.
+ dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ std::vector<TempPath> files;
+ for (int i = 0; i < file_count; i++) {
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ files.push_back(std::move(file));
+ }
+ state.ResumeTiming();
+
+ while (!files.empty()) {
+ // Destructor unlinks.
+ files.pop_back();
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc
new file mode 100644
index 000000000..7b060c70e
--- /dev/null
+++ b/test/perf/linux/write_benchmark.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Write(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), size);
+
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD
index bca5f9cab..6859541ad 100644
--- a/test/root/testdata/BUILD
+++ b/test/root/testdata/BUILD
@@ -13,6 +13,6 @@ go_library(
"simple.go",
],
visibility = [
- "//visibility:public",
+ "//:sandbox",
],
)
diff --git a/test/runner/BUILD b/test/runner/BUILD
new file mode 100644
index 000000000..9959ef9b0
--- /dev/null
+++ b/test/runner/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ data = [
+ "//runsc",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//runsc/specutils",
+ "//runsc/testutil",
+ "//test/runner/gtest",
+ "//test/uds",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
new file mode 100644
index 000000000..56743a526
--- /dev/null
+++ b/test/runner/defs.bzl
@@ -0,0 +1,198 @@
+"""Defines a rule for syscall test targets."""
+
+load("//tools:defs.bzl", "default_platform", "loopback", "platforms")
+
+def _runner_test_impl(ctx):
+ # Generate a runner binary.
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "set -euf -x -o pipefail",
+ "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then",
+ " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ "fi",
+ "exec %s %s %s\n" % (
+ ctx.files.runner[0].short_path,
+ " ".join(ctx.attr.runner_args),
+ ctx.files.test[0].short_path,
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ depset(target.data_runfiles.files)
+ for target in (ctx.attr.runner, ctx.attr.test)
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.runner + ctx.files.test,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_runner_test = rule(
+ attrs = {
+ "runner": attr.label(
+ default = "//test/runner:runner",
+ ),
+ "test": attr.label(
+ mandatory = True,
+ ),
+ "runner_args": attr.string_list(),
+ "data": attr.label_list(
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _runner_test_impl,
+)
+
+def _syscall_test(
+ test,
+ shard_count,
+ size,
+ platform,
+ use_tmpfs,
+ tags,
+ network = "none",
+ file_access = "exclusive",
+ overlay = False,
+ add_uds_tree = False):
+ # Prepend "runsc" to non-native platform names.
+ full_platform = platform if platform == "native" else "runsc_" + platform
+
+ # Name the test appropriately.
+ name = test.split(":")[1] + "_" + full_platform
+ if file_access == "shared":
+ name += "_shared"
+ if overlay:
+ name += "_overlay"
+ if network != "none":
+ name += "_" + network + "net"
+
+ # Apply all tags.
+ if tags == None:
+ tags = []
+
+ # Add the full_platform and file access in a tag to make it easier to run
+ # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
+ tags += [full_platform, "file_" + file_access]
+
+ # Hash this target into one of 15 buckets. This can be used to
+ # randomly split targets between different workflows.
+ hash15 = hash(native.package_name() + name) % 15
+ tags.append("hash15:" + str(hash15))
+
+ # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
+ # we figure out how to request ipv4 sockets on Guitar machines.
+ if network == "host":
+ tags.append("noguitar")
+
+ # Disable off-host networking.
+ tags.append("requires-net:loopback")
+
+ runner_args = [
+ # Arguments are passed directly to runner binary.
+ "--platform=" + platform,
+ "--network=" + network,
+ "--use-tmpfs=" + str(use_tmpfs),
+ "--file-access=" + file_access,
+ "--overlay=" + str(overlay),
+ "--add-uds-tree=" + str(add_uds_tree),
+ ]
+
+ # Call the rule above.
+ _runner_test(
+ name = name,
+ test = test,
+ runner_args = runner_args,
+ data = [loopback],
+ size = size,
+ tags = tags,
+ shard_count = shard_count,
+ )
+
+def syscall_test(
+ test,
+ shard_count = 5,
+ size = "small",
+ use_tmpfs = False,
+ add_overlay = False,
+ add_uds_tree = False,
+ add_hostinet = False,
+ tags = None):
+ """syscall_test is a macro that will create targets for all platforms.
+
+ Args:
+ test: the test target.
+ shard_count: shards for defined tests.
+ size: the defined test size.
+ use_tmpfs: use tmpfs in the defined tests.
+ add_overlay: add an overlay test.
+ add_uds_tree: add a UDS test.
+ add_hostinet: add a hostinet test.
+ tags: starting test tags.
+ """
+ if not tags:
+ tags = []
+
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = "native",
+ use_tmpfs = False,
+ add_uds_tree = add_uds_tree,
+ tags = tags,
+ )
+
+ for (platform, platform_tags) in platforms.items():
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platform_tags + tags,
+ )
+
+ if add_overlay:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ overlay = True,
+ )
+
+ if not use_tmpfs:
+ # Also test shared gofer access.
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ file_access = "shared",
+ )
+
+ if add_hostinet:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ network = "host",
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ )
diff --git a/test/syscalls/gtest/BUILD b/test/runner/gtest/BUILD
index de4b2727c..de4b2727c 100644
--- a/test/syscalls/gtest/BUILD
+++ b/test/runner/gtest/BUILD
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
new file mode 100644
index 000000000..f96e2415e
--- /dev/null
+++ b/test/runner/gtest/gtest.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gtest contains helpers for running google-test tests from Go.
+package gtest
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+)
+
+var (
+ // listTestFlag is the flag that will list tests in gtest binaries.
+ listTestFlag = "--gtest_list_tests"
+
+ // filterTestFlag is the flag that will filter tests in gtest binaries.
+ filterTestFlag = "--gtest_filter"
+
+ // listBechmarkFlag is the flag that will list benchmarks in gtest binaries.
+ listBenchmarkFlag = "--benchmark_list_tests"
+
+ // filterBenchmarkFlag is the flag that will run specified benchmarks.
+ filterBenchmarkFlag = "--benchmark_filter"
+)
+
+// TestCase is a single gtest test case.
+type TestCase struct {
+ // Suite is the suite for this test.
+ Suite string
+
+ // Name is the name of this individual test.
+ Name string
+
+ // all indicates that this will run without flags. This takes
+ // precendence over benchmark below.
+ all bool
+
+ // benchmark indicates that this is a benchmark. In this case, the
+ // suite will be empty, and we will use the appropriate test and
+ // benchmark flags.
+ benchmark bool
+}
+
+// FullName returns the name of the test including the suite. It is suitable to
+// pass to "-gtest_filter".
+func (tc TestCase) FullName() string {
+ return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
+}
+
+// Args returns arguments to be passed when invoking the test.
+func (tc TestCase) Args() []string {
+ if tc.all {
+ return []string{} // No arguments.
+ }
+ if tc.benchmark {
+ return []string{
+ fmt.Sprintf("%s=^$", filterTestFlag),
+ fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ }
+ }
+ return []string{
+ fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()),
+ fmt.Sprintf("%s=^$", filterBenchmarkFlag),
+ }
+}
+
+// ParseTestCases calls a gtest test binary to list its test and returns a
+// slice with the name and suite of each test.
+//
+// If benchmarks is true, then benchmarks will be included in the list of test
+// cases provided. Note that this requires the binary to support the
+// benchmarks_list_tests flag.
+func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) {
+ // Run to extract test cases.
+ args := append([]string{listTestFlag}, extraArgs...)
+ cmd := exec.Command(testBin, args...)
+ out, err := cmd.Output()
+ if err != nil {
+ // We failed to list tests with the given flags. Just
+ // return something that will run the binary with no
+ // flags, which should execute all tests.
+ return []TestCase{
+ TestCase{
+ Suite: "Default",
+ Name: "All",
+ all: true,
+ },
+ }, nil
+ }
+
+ // Parse test output.
+ var t []TestCase
+ var suite string
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // New suite?
+ if !strings.HasPrefix(line, " ") {
+ suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
+ continue
+ }
+
+ // Individual test.
+ name := strings.TrimSpace(line)
+
+ // Do we have a suite yet?
+ if suite == "" {
+ return nil, fmt.Errorf("test without a suite: %v", name)
+ }
+
+ // Add this individual test.
+ t = append(t, TestCase{
+ Suite: suite,
+ Name: name,
+ })
+ }
+
+ // Finished?
+ if !benchmarks {
+ return t, nil
+ }
+
+ // Run again to extract benchmarks.
+ args = append([]string{listBenchmarkFlag}, extraArgs...)
+ cmd = exec.Command(testBin, args...)
+ out, err = cmd.Output()
+ if err != nil {
+ // We were able to enumerate tests above, but not benchmarks?
+ // We requested them, so we return an error in this case.
+ exitErr, ok := err.(*exec.ExitError)
+ if !ok {
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err)
+ }
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
+ }
+
+ // Parse benchmark output.
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // Single benchmark.
+ name := strings.TrimSpace(line)
+
+ // Add the single benchmark.
+ t = append(t, TestCase{
+ Suite: "Benchmarks",
+ Name: name,
+ benchmark: true,
+ })
+ }
+
+ return t, nil
+}
diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go
index b9fd885ff..a78ef38e0 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/runner/runner.go
@@ -34,15 +34,11 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
"gvisor.dev/gvisor/runsc/testutil"
- "gvisor.dev/gvisor/test/syscalls/gtest"
+ "gvisor.dev/gvisor/test/runner/gtest"
"gvisor.dev/gvisor/test/uds"
)
-// Location of syscall tests, relative to the repo root.
-const testDir = "test/syscalls/linux"
-
var (
- testName = flag.String("test-name", "", "name of test binary to run")
debug = flag.Bool("debug", false, "enable debug logs")
strace = flag.Bool("strace", false, "enable strace logs")
platform = flag.String("platform", "ptrace", "platform to run on")
@@ -103,7 +99,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
}
- cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ cmd := exec.Command(testBin, tc.Args()...)
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -296,7 +292,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Run a new container with the test executable and filter for the
// given test suite and name.
- spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
@@ -404,9 +400,10 @@ func matchString(a, b string) (bool, error) {
func main() {
flag.Parse()
- if *testName == "" {
- fatalf("test-name flag must be provided")
+ if flag.NArg() != 1 {
+ fatalf("test must be provided")
}
+ testBin := flag.Args()[0] // Only argument.
log.SetLevel(log.Info)
if *debug {
@@ -436,34 +433,31 @@ func main() {
}
}
- // Get path to test binary.
- fullTestName := filepath.Join(testDir, *testName)
- testBin, err := testutil.FindFile(fullTestName)
- if err != nil {
- fatalf("FindFile(%q) failed: %v", fullTestName, err)
- }
-
// Get all test cases in each binary.
- testCases, err := gtest.ParseTestCases(testBin)
+ testCases, err := gtest.ParseTestCases(testBin, true)
if err != nil {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
// Get subset of tests corresponding to shard.
- begin, end, err := testutil.TestBoundsForShard(len(testCases))
+ indices, err := testutil.TestIndicesForShard(len(testCases))
if err != nil {
fatalf("TestsForShard() failed: %v", err)
}
- testCases = testCases[begin:end]
+
+ // Resolve the absolute path for the binary.
+ testBin, err = filepath.Abs(testBin)
+ if err != nil {
+ fatalf("Abs() failed: %v", err)
+ }
// Run the tests.
var tests []testing.InternalTest
- for _, tc := range testCases {
+ for _, tci := range indices {
// Capture tc.
- tc := tc
- testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name)
+ tc := testCases[tci]
tests = append(tests, testing.InternalTest{
- Name: testName,
+ Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name),
F: func(t *testing.T) {
if *parallel {
t.Parallel()
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
index e41e78f77..42d722553 100644
--- a/test/runtimes/README.md
+++ b/test/runtimes/README.md
@@ -12,24 +12,39 @@ The following runtimes are currently supported:
- PHP 7.3
- Python 3.7
-#### Prerequisites:
+### Building and pushing the images:
-1) [Install and configure Docker](https://docs.docker.com/install/)
-
-2) Build each Docker container from the runtimes/images directory:
+The canonical source of images is the
+[gvisor-presubmit container registry](https://gcr.io/gvisor-presubmit/). You can
+build new images with the following command:
```bash
$ cd images
$ docker build -f Dockerfile_$LANG [-t $NAME] .
```
-### Testing:
+To push them to our container registry, set the tag in the command above to
+`gcr.io/gvisor-presubmit/$LANG`, then push them. (Note that you will need
+appropriate permissions to the `gvisor-presubmit` GCP project.)
+
+```bash
+gcloud docker -- push gcr.io/gvisor-presubmit/$LANG
+```
+
+#### Running in Docker locally:
+
+1) [Install and configure Docker](https://docs.docker.com/install/)
+
+2) Pull the image you want to run:
+
+```bash
+$ docker pull gcr.io/gvisor-presubmit/$LANG
+```
-If the prerequisites have been fulfilled, you can run the tests with the
-following command:
+3) Run docker with the image.
```bash
-$ docker run --rm -it $NAME [FLAG]
+$ docker run [--runtime=runsc] --rm -it $NAME [FLAG]
```
Running the command with no flags will cause all the available tests to execute.
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
index bec37c69d..ddb890dbc 100644
--- a/test/runtimes/runner.go
+++ b/test/runtimes/runner.go
@@ -20,7 +20,6 @@ import (
"flag"
"fmt"
"io"
- "log"
"os"
"sort"
"strings"
@@ -101,17 +100,15 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int
// shard.
tests := strings.Fields(list)
sort.Strings(tests)
- begin, end, err := testutil.TestBoundsForShard(len(tests))
+ indices, err := testutil.TestIndicesForShard(len(tests))
if err != nil {
return nil, fmt.Errorf("TestsForShard() failed: %v", err)
}
- log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests))
- tests = tests[begin:end]
var itests []testing.InternalTest
- for _, tc := range tests {
+ for _, tci := range indices {
// Capture tc in this scope.
- tc := tc
+ tc := tests[tci]
itests = append(itests, testing.InternalTest{
Name: tc,
F: func(t *testing.T) {
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 40e974314..3518e862d 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1,5 +1,4 @@
-load("//tools:defs.bzl", "go_binary")
-load("//test/syscalls:build_defs.bzl", "syscall_test")
+load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -46,6 +45,15 @@ syscall_test(test = "//test/syscalls/linux:brk_test")
syscall_test(test = "//test/syscalls/linux:socket_test")
syscall_test(
+ size = "large",
+ shard_count = 50,
+ # Takes too long for TSAN. Since this is kind of a stress test that doesn't
+ # involve much concurrency, TSAN's usefulness here is limited anyway.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_stress_test",
+)
+
+syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:chdir_test",
)
@@ -225,7 +233,6 @@ syscall_test(
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:mknod_test",
- use_tmpfs = True, # mknod is not supported over gofer.
)
syscall_test(
@@ -251,6 +258,8 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:munmap_test")
+syscall_test(test = "//test/syscalls/linux:network_namespace_test")
+
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:open_create_test",
@@ -669,6 +678,8 @@ syscall_test(
test = "//test/syscalls/linux:truncate_test",
)
+syscall_test(test = "//test/syscalls/linux:tuntap_test")
+
syscall_test(test = "//test/syscalls/linux:udp_bind_test")
syscall_test(
@@ -718,21 +729,3 @@ syscall_test(test = "//test/syscalls/linux:proc_net_unix_test")
syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
-
-go_binary(
- name = "syscall_test_runner",
- testonly = 1,
- srcs = ["syscall_test_runner.go"],
- data = [
- "//runsc",
- ],
- deps = [
- "//pkg/log",
- "//runsc/specutils",
- "//runsc/testutil",
- "//test/syscalls/gtest",
- "//test/uds",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
deleted file mode 100644
index 1df761dd0..000000000
--- a/test/syscalls/build_defs.bzl
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Defines a rule for syscall test targets."""
-
-load("//tools:defs.bzl", "loopback")
-
-# syscall_test is a macro that will create targets to run the given test target
-# on the host (native) and runsc.
-def syscall_test(
- test,
- shard_count = 5,
- size = "small",
- use_tmpfs = False,
- add_overlay = False,
- add_uds_tree = False,
- add_hostinet = False,
- tags = None):
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "native",
- use_tmpfs = False,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "kvm",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- if add_overlay:
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
- add_uds_tree = add_uds_tree,
- tags = tags,
- overlay = True,
- )
-
- if not use_tmpfs:
- # Also test shared gofer access.
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- file_access = "shared",
- )
-
- if add_hostinet:
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- network = "host",
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
-def _syscall_test(
- test,
- shard_count,
- size,
- platform,
- use_tmpfs,
- tags,
- network = "none",
- file_access = "exclusive",
- overlay = False,
- add_uds_tree = False):
- test_name = test.split(":")[1]
-
- # Prepend "runsc" to non-native platform names.
- full_platform = platform if platform == "native" else "runsc_" + platform
-
- name = test_name + "_" + full_platform
- if file_access == "shared":
- name += "_shared"
- if overlay:
- name += "_overlay"
- if network != "none":
- name += "_" + network + "net"
-
- if tags == None:
- tags = []
-
- # Add the full_platform and file access in a tag to make it easier to run
- # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
- tags += [full_platform, "file_" + file_access]
-
- # Add tag to prevent the tests from running in a Bazel sandbox.
- # TODO(b/120560048): Make the tests run without this tag.
- tags.append("no-sandbox")
-
- # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is
- # more stable.
- if platform == "kvm":
- tags += ["manual"]
- tags += ["requires-kvm"]
-
- args = [
- # Arguments are passed directly to syscall_test_runner binary.
- "--test-name=" + test_name,
- "--platform=" + platform,
- "--network=" + network,
- "--use-tmpfs=" + str(use_tmpfs),
- "--file-access=" + file_access,
- "--overlay=" + str(overlay),
- "--add-uds-tree=" + str(add_uds_tree),
- ]
-
- sh_test(
- srcs = ["syscall_test_runner.sh"],
- name = name,
- data = [
- ":syscall_test_runner",
- loopback,
- test,
- ],
- args = args,
- size = size,
- tags = tags,
- shard_count = shard_count,
- )
-
-def sh_test(**kwargs):
- """Wraps the standard sh_test."""
- native.sh_test(
- **kwargs
- )
diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go
deleted file mode 100644
index bdec8eb07..000000000
--- a/test/syscalls/gtest/gtest.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package gtest contains helpers for running google-test tests from Go.
-package gtest
-
-import (
- "fmt"
- "os/exec"
- "strings"
-)
-
-var (
- // ListTestFlag is the flag that will list tests in gtest binaries.
- ListTestFlag = "--gtest_list_tests"
-
- // FilterTestFlag is the flag that will filter tests in gtest binaries.
- FilterTestFlag = "--gtest_filter"
-)
-
-// TestCase is a single gtest test case.
-type TestCase struct {
- // Suite is the suite for this test.
- Suite string
-
- // Name is the name of this individual test.
- Name string
-}
-
-// FullName returns the name of the test including the suite. It is suitable to
-// pass to "-gtest_filter".
-func (tc TestCase) FullName() string {
- return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
-}
-
-// ParseTestCases calls a gtest test binary to list its test and returns a
-// slice with the name and suite of each test.
-func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) {
- args := append([]string{ListTestFlag}, extraArgs...)
- cmd := exec.Command(testBin, args...)
- out, err := cmd.Output()
- if err != nil {
- exitErr, ok := err.(*exec.ExitError)
- if !ok {
- return nil, fmt.Errorf("could not enumerate gtest tests: %v", err)
- }
- return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr)
- }
-
- var t []TestCase
- var suite string
- for _, line := range strings.Split(string(out), "\n") {
- // Strip comments.
- line = strings.Split(line, "#")[0]
-
- // New suite?
- if !strings.HasPrefix(line, " ") {
- suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
- continue
- }
-
- // Individual test.
- name := strings.TrimSpace(line)
-
- // Do we have a suite yet?
- if suite == "" {
- return nil, fmt.Errorf("test without a suite: %v", name)
- }
-
- // Add this individual test.
- t = append(t, TestCase{
- Suite: suite,
- Name: name,
- })
-
- }
-
- if len(t) == 0 {
- return nil, fmt.Errorf("no tests parsed from %v", testBin)
- }
- return t, nil
-}
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index 2751fb4e7..3c825477c 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -74,7 +74,7 @@ void ExitGroup32(const char instruction[2], int code) {
"int $3\n"
:
: [ code ] "m"(code), [ ip ] "d"(m.ptr())
- : "rax", "rbx", "rsp");
+ : "rax", "rbx");
}
constexpr int kExitCode = 42;
@@ -102,7 +102,8 @@ TEST(Syscall32Bit, Int80) {
}
TEST(Syscall32Bit, Sysenter) {
- if (PlatformSupport32Bit() == PlatformSupport::Allowed &&
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
GetCPUVendor() == CPUVendor::kAMD) {
// SYSENTER is an illegal instruction in compatibility mode on AMD.
EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
@@ -133,7 +134,8 @@ TEST(Syscall32Bit, Sysenter) {
}
TEST(Syscall32Bit, Syscall) {
- if (PlatformSupport32Bit() == PlatformSupport::Allowed &&
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
GetCPUVendor() == CPUVendor::kIntel) {
// SYSCALL is an illegal instruction in compatibility mode on Intel.
EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
@@ -153,7 +155,7 @@ TEST(Syscall32Bit, Syscall) {
case PlatformSupport::Ignored:
// See above.
EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
- ::testing::KilledBySignal(SIGILL), "");
+ ::testing::KilledBySignal(SIGSEGV), "");
break;
case PlatformSupport::Allowed:
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 74bf068ec..704bae17b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "select_system")
+load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system")
package(
default_visibility = ["//:sandbox"],
@@ -10,9 +10,20 @@ exports_files(
"socket.cc",
"socket_inet_loopback.cc",
"socket_ip_loopback_blocking.cc",
+ "socket_ip_tcp_generic_loopback.cc",
"socket_ip_tcp_loopback.cc",
+ "socket_ip_tcp_loopback_blocking.cc",
+ "socket_ip_tcp_loopback_nonblock.cc",
+ "socket_ip_tcp_udp_generic.cc",
+ "socket_ip_udp_loopback.cc",
+ "socket_ip_udp_loopback_blocking.cc",
+ "socket_ip_udp_loopback_nonblock.cc",
+ "socket_ip_unbound.cc",
+ "socket_ipv4_tcp_unbound_external_networking_test.cc",
+ "socket_ipv4_udp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_loopback.cc",
"tcp_socket.cc",
+ "udp_bind.cc",
"udp_socket.cc",
],
visibility = ["//:sandbox"],
@@ -82,14 +93,14 @@ cc_library(
srcs = ["base_poll_test.cc"],
hdrs = ["base_poll_test.h"],
deps = [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -99,11 +110,11 @@ cc_library(
hdrs = ["file_base.h"],
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -121,6 +132,17 @@ cc_library(
)
cc_library(
+ name = "socket_netlink_route_util",
+ testonly = 1,
+ srcs = ["socket_netlink_route_util.cc"],
+ hdrs = ["socket_netlink_route_util.h"],
+ deps = [
+ ":socket_netlink_util",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
name = "socket_test_util",
testonly = 1,
srcs = [
@@ -130,11 +152,12 @@ cc_library(
hdrs = ["socket_test_util.h"],
defines = select_system(),
deps = default_net_util() + [
- "@com_google_googletest//:gtest",
+ gtest,
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:optional",
"//test/util:file_descriptor",
"//test/util:posix_error",
"//test/util:temp_path",
@@ -155,9 +178,9 @@ cc_library(
hdrs = ["unix_domain_socket_test_util.h"],
deps = [
":socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
)
@@ -179,30 +202,33 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "32bit_test",
testonly = 1,
- srcs = ["32bit.cc"],
+ srcs = select_arch(
+ amd64 = ["32bit.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:memory_util",
"//test/util:platform_util",
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -215,9 +241,9 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -230,9 +256,9 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -244,10 +270,10 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -259,12 +285,12 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -277,12 +303,11 @@ cc_binary(
],
linkstatic = 1,
deps = [
- # The heapchecker doesn't recognize that io_destroy munmaps.
- "@com_google_googletest//:gtest",
- "@com_google_absl//absl/strings",
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:proc_util",
@@ -299,12 +324,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -317,9 +342,9 @@ cc_binary(
"//:sandbox",
],
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -331,9 +356,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -345,9 +370,9 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -369,10 +394,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -385,10 +410,10 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -401,14 +426,14 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -421,12 +446,12 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -440,12 +465,12 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:mount_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -455,9 +480,9 @@ cc_binary(
srcs = ["clock_getres.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -467,11 +492,11 @@ cc_binary(
srcs = ["clock_gettime.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -481,13 +506,13 @@ cc_binary(
srcs = ["concurrency.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:platform_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -500,9 +525,9 @@ cc_binary(
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -513,10 +538,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -527,9 +552,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -541,11 +566,11 @@ cc_binary(
deps = [
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -558,10 +583,10 @@ cc_binary(
"//test/util:epoll_util",
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -573,10 +598,10 @@ cc_binary(
deps = [
"//test/util:epoll_util",
"//test/util:eventfd_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -586,12 +611,12 @@ cc_binary(
srcs = ["exceptions.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:platform_util",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -601,10 +626,10 @@ cc_binary(
srcs = ["getcpu.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -614,10 +639,10 @@ cc_binary(
srcs = ["getcpu.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -627,33 +652,36 @@ cc_binary(
srcs = ["getrusage.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "exec_binary_test",
testonly = 1,
- srcs = ["exec_binary.cc"],
+ srcs = select_arch(
+ amd64 = ["exec_binary.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -676,15 +704,15 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:optional",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest",
],
)
@@ -695,11 +723,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -712,10 +740,10 @@ cc_binary(
":file_base",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -725,9 +753,9 @@ cc_binary(
srcs = ["fault.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -738,10 +766,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -754,18 +782,19 @@ cc_binary(
":socket_test_util",
"//test/util:cleanup",
"//test/util:eventfd_util",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:save_util",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:timer_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -779,15 +808,15 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -798,40 +827,46 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "fpsig_fork_test",
testonly = 1,
- srcs = ["fpsig_fork.cc"],
+ srcs = select_arch(
+ amd64 = ["fpsig_fork.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "fpsig_nested_test",
testonly = 1,
- srcs = ["fpsig_nested.cc"],
+ srcs = select_arch(
+ amd64 = ["fpsig_nested.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -842,10 +877,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -856,10 +891,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -871,6 +906,9 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:memory_util",
"//test/util:save_util",
"//test/util:temp_path",
@@ -879,9 +917,6 @@ cc_binary(
"//test/util:thread_util",
"//test/util:time_util",
"//test/util:timer_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -894,12 +929,12 @@ cc_binary(
"//test/util:eventfd_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -909,9 +944,9 @@ cc_binary(
srcs = ["getrandom.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -944,10 +979,10 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -971,9 +1006,9 @@ cc_binary(
":socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -984,6 +1019,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -991,9 +1029,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1005,15 +1040,15 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1026,14 +1061,14 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1044,10 +1079,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1058,6 +1093,7 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1065,7 +1101,6 @@ cc_binary(
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1076,12 +1111,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/memory",
+ gtest,
"//test/util:memory_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest",
],
)
@@ -1091,11 +1126,11 @@ cc_binary(
srcs = ["mincore.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1108,10 +1143,10 @@ cc_binary(
":temp_umask",
"//test/util:capability_util",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1122,11 +1157,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1138,12 +1173,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:cleanup",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:rlimit_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1156,13 +1191,13 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1175,6 +1210,9 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:mount_util",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -1182,9 +1220,6 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1194,10 +1229,9 @@ cc_binary(
srcs = ["mremap.cc"],
linkstatic = 1,
deps = [
- # The heap check fails due to MremapDeathTest
- "@com_google_googletest//:gtest",
- "@com_google_absl//absl/strings",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1229,9 +1263,9 @@ cc_binary(
srcs = ["munmap.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1248,14 +1282,14 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1269,10 +1303,10 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1286,11 +1320,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1304,11 +1338,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1320,16 +1354,16 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:pty_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1341,12 +1375,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:posix_error",
"//test/util:pty_util",
"//test/util:test_main",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1359,12 +1393,12 @@ cc_binary(
"//test/syscalls/linux:socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1374,13 +1408,13 @@ cc_binary(
srcs = ["pause.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1391,15 +1425,16 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1412,13 +1447,13 @@ cc_binary(
":base_poll_test",
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1429,11 +1464,11 @@ cc_binary(
linkstatic = 1,
deps = [
":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1444,9 +1479,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1458,12 +1493,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -1474,13 +1509,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -1491,10 +1526,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1505,6 +1540,8 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:temp_path",
@@ -1512,8 +1549,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1527,13 +1562,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1545,11 +1580,11 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1563,6 +1598,10 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:temp_path",
@@ -1570,10 +1609,6 @@ cc_binary(
"//test/util:thread_util",
"//test/util:time_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1587,11 +1622,11 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1603,17 +1638,17 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest",
],
)
@@ -1627,6 +1662,8 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -1634,8 +1671,6 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1646,11 +1681,11 @@ cc_binary(
linkstatic = 1,
deps = [
":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1660,6 +1695,9 @@ cc_binary(
srcs = ["ptrace.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:platform_util",
@@ -1667,9 +1705,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1679,10 +1714,10 @@ cc_binary(
srcs = ["pwrite64.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1696,12 +1731,12 @@ cc_binary(
deps = [
":file_base",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1715,11 +1750,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1733,10 +1768,10 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1750,10 +1785,10 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1764,10 +1799,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1778,10 +1813,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1797,13 +1832,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1818,12 +1853,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1837,11 +1872,11 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1865,11 +1900,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/syscalls/linux/rseq:lib",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1880,11 +1915,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ gtest,
"//test/util:logging",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1894,9 +1929,9 @@ cc_binary(
srcs = ["sched.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1906,9 +1941,9 @@ cc_binary(
srcs = ["sched_yield.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1918,6 +1953,8 @@ cc_binary(
srcs = ["seccomp.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1925,8 +1962,6 @@ cc_binary(
"//test/util:proc_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1938,14 +1973,14 @@ cc_binary(
deps = [
":base_poll_test",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:rlimit_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1957,13 +1992,13 @@ cc_binary(
deps = [
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1975,12 +2010,12 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1991,13 +2026,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2007,9 +2042,9 @@ cc_binary(
srcs = ["sigaction.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2024,28 +2059,34 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:fs_util",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "sigiret_test",
testonly = 1,
- srcs = ["sigiret.cc"],
+ srcs = select_arch(
+ amd64 = ["sigiret.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:timer_util",
- "@com_google_googletest//:gtest",
- ],
+ ] + select_arch(
+ amd64 = [],
+ arm64 = ["//test/util:test_main"],
+ ),
)
cc_binary(
@@ -2055,14 +2096,14 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:logging",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -2072,10 +2113,10 @@ cc_binary(
srcs = ["sigprocmask.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2085,13 +2126,13 @@ cc_binary(
srcs = ["sigstop.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2102,13 +2143,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2124,14 +2165,30 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
alwayslink = 1,
)
+cc_binary(
+ name = "socket_stress_test",
+ testonly = 1,
+ srcs = [
+ "socket_generic_stress.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "socket_unix_dgram_test_cases",
testonly = 1,
@@ -2140,8 +2197,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2154,8 +2211,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2171,10 +2228,11 @@ cc_library(
],
deps = [
":socket_test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2191,8 +2249,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2209,9 +2267,9 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:memory_util",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2229,8 +2287,8 @@ cc_library(
":ip_socket_test_util",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2247,8 +2305,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2265,9 +2323,9 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2284,8 +2342,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2302,8 +2360,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2366,9 +2424,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2398,9 +2456,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2430,9 +2488,9 @@ cc_binary(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2530,10 +2588,10 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2549,10 +2607,10 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2568,10 +2626,10 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2617,9 +2675,9 @@ cc_binary(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2698,15 +2756,15 @@ cc_binary(
":ip_socket_test_util",
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:save_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2718,9 +2776,9 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2732,12 +2790,14 @@ cc_binary(
deps = [
":socket_netlink_util",
":socket_test_util",
+ "//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
],
)
@@ -2750,9 +2810,9 @@ cc_binary(
":socket_netlink_util",
":socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2770,9 +2830,9 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2789,11 +2849,11 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2810,10 +2870,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2830,10 +2890,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2850,11 +2910,11 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2871,8 +2931,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2889,10 +2949,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2986,9 +3046,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3000,9 +3060,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3014,9 +3074,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3031,9 +3091,9 @@ cc_binary(
":socket_blocking_test_cases",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3048,9 +3108,9 @@ cc_binary(
":ip_socket_test_util",
":socket_blocking_test_cases",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3065,9 +3125,9 @@ cc_binary(
":socket_non_stream_blocking_test_cases",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3082,9 +3142,9 @@ cc_binary(
":ip_socket_test_util",
":socket_non_stream_blocking_test_cases",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3100,9 +3160,9 @@ cc_binary(
":socket_unix_cmsg_test_cases",
":socket_unix_test_cases",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3114,9 +3174,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3128,9 +3188,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3143,10 +3203,10 @@ cc_binary(
":socket_netlink_util",
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:endian",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
],
)
@@ -3162,12 +3222,12 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3178,11 +3238,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3196,12 +3256,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3214,10 +3274,10 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3227,10 +3287,10 @@ cc_binary(
srcs = ["sync.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3240,10 +3300,10 @@ cc_binary(
srcs = ["sysinfo.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3253,9 +3313,9 @@ cc_binary(
srcs = ["syslog.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3265,10 +3325,10 @@ cc_binary(
srcs = ["sysret.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3280,12 +3340,12 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3295,11 +3355,11 @@ cc_binary(
srcs = ["tgkill.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3309,10 +3369,10 @@ cc_binary(
srcs = ["time.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3337,15 +3397,15 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3355,11 +3415,11 @@ cc_binary(
srcs = ["tkill.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3373,11 +3433,30 @@ cc_binary(
"//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "tuntap_test",
+ testonly = 1,
+ srcs = ["tuntap.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/syscalls/linux:socket_netlink_route_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3393,12 +3472,12 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -3421,9 +3500,9 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3434,14 +3513,14 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:uid_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3452,11 +3531,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3469,11 +3548,11 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3483,11 +3562,11 @@ cc_binary(
srcs = ["unshare.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -3513,11 +3592,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ gtest,
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3527,13 +3606,13 @@ cc_binary(
srcs = ["vfork.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3545,6 +3624,10 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -3553,10 +3636,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3567,10 +3646,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3581,30 +3660,47 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
- name = "semaphore_test",
+ name = "network_namespace_test",
testonly = 1,
- srcs = ["semaphore.cc"],
+ srcs = ["network_namespace.cc"],
linkstatic = 1,
deps = [
+ ":socket_test_util",
+ gtest,
"//test/util:capability_util",
+ "//test/util:memory_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "semaphore_test",
+ testonly = 1,
+ srcs = ["semaphore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
],
)
@@ -3630,10 +3726,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3643,11 +3739,11 @@ cc_binary(
srcs = ["vdso_clock_gettime.cc"],
linkstatic = 1,
deps = [
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -3657,10 +3753,10 @@ cc_binary(
srcs = ["vsyscall.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3673,11 +3769,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -3689,12 +3785,12 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3706,10 +3802,10 @@ cc_binary(
deps = [
":ip_socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3721,10 +3817,10 @@ cc_binary(
deps = [
":ip_socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3740,11 +3836,12 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc
index d89269985..940c97285 100644
--- a/test/syscalls/linux/alarm.cc
+++ b/test/syscalls/linux/alarm.cc
@@ -188,6 +188,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc
index f246a799e..adfb149df 100644
--- a/test/syscalls/linux/bad.cc
+++ b/test/syscalls/linux/bad.cc
@@ -22,11 +22,17 @@ namespace gvisor {
namespace testing {
namespace {
+#ifdef __x86_64__
+// get_kernel_syms is not supported in Linux > 2.6, and not implemented in
+// gVisor.
+constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms;
+#elif __aarch64__
+// Use the last of arch_specific_syscalls which are not implemented on arm64.
+constexpr uint32_t kNotImplementedSyscall = SYS_arch_specific_syscall + 15;
+#endif
TEST(BadSyscallTest, NotImplemented) {
- // get_kernel_syms is not supported in Linux > 2.6, and not implemented in
- // gVisor.
- EXPECT_THAT(syscall(SYS_get_kernel_syms), SyscallFailsWithErrno(ENOSYS));
+ EXPECT_THAT(syscall(kNotImplementedSyscall), SyscallFailsWithErrno(ENOSYS));
}
TEST(BadSyscallTest, NegativeOne) {
diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc
index 04bc2d7b9..85ec013d5 100644
--- a/test/syscalls/linux/chroot.cc
+++ b/test/syscalls/linux/chroot.cc
@@ -162,12 +162,12 @@ TEST(ChrootTest, DotDotFromOpenFD) {
// getdents on fd should not error.
char buf[1024];
- ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)),
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
SyscallSucceeds());
}
// Test that link resolution in a chroot can escape the root by following an
-// open proc fd.
+// open proc fd. Regression test for b/32316719.
TEST(ChrootTest, ProcFdLinkResolutionInChroot) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc
index f41f99900..7cd6a75bd 100644
--- a/test/syscalls/linux/concurrency.cc
+++ b/test/syscalls/linux/concurrency.cc
@@ -46,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) {
}
// Test that multiple threads in this process continue to execute in parallel,
-// even if an unrelated second process is spawned.
+// even if an unrelated second process is spawned. Regression test for
+// b/32119508.
TEST(ConcurrencyTest, MultiProcessMultithreaded) {
// In PID 1, start TIDs 1 and 2, and put both to sleep.
//
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4dd302eed..4e473268c 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
+TEST(DevTest, NetTunExists) {
+ struct stat statbuf = {};
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index b5e0a512b..07bd527e6 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -868,6 +868,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc
index b790fe5be..2989379b7 100644
--- a/test/syscalls/linux/exec_proc_exe_workload.cc
+++ b/test/syscalls/linux/exec_proc_exe_workload.cc
@@ -21,6 +21,12 @@
#include "test/util/posix_error.h"
int main(int argc, char** argv, char** envp) {
+ // This is annoying. Because remote build systems may put these binaries
+ // in a content-addressable-store, you may wind up with /proc/self/exe
+ // pointing to some random path (but with a sensible argv[0]).
+ //
+ // Therefore, this test simply checks that the /proc/self/exe
+ // is absolute and *doesn't* match argv[1].
std::string exe =
gvisor::testing::ProcessExePath(getpid()).ValueOrDie();
if (exe[0] != '/') {
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
index 1c3d00287..7819f4ac3 100644
--- a/test/syscalls/linux/fallocate.cc
+++ b/test/syscalls/linux/fallocate.cc
@@ -33,7 +33,7 @@ namespace testing {
namespace {
int fallocate(int fd, int mode, off_t offset, off_t len) {
- return syscall(__NR_fallocate, fd, mode, offset, len);
+ return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len);
}
class AllocateTest : public FileTest {
@@ -47,27 +47,27 @@ TEST_F(AllocateTest, Fallocate) {
EXPECT_EQ(buf.st_size, 0);
// Grow to ten bytes.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Allocate to a smaller size should be noop.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Grow again.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 20);
// Grow with offset.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 30);
// Grow with offset beyond EOF.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 40);
}
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 4f3aa81d6..c7cc5816e 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -31,6 +31,7 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/cleanup.h"
#include "test/util/eventfd_util.h"
+#include "test/util/fs_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
#include "test/util/save_util.h"
@@ -55,10 +56,6 @@ ABSL_FLAG(int32_t, socket_fd, -1,
namespace gvisor {
namespace testing {
-// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
-// because "it isn't needed", even though Linux can return it via F_GETFL.
-constexpr int kOLargeFile = 00100000;
-
class FcntlLockTest : public ::testing::Test {
public:
void SetUp() override {
@@ -1131,5 +1128,5 @@ int main(int argc, char** argv) {
exit(err);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
index 371890110..ff8bdfeb0 100644
--- a/test/syscalls/linux/fork.cc
+++ b/test/syscalls/linux/fork.cc
@@ -215,6 +215,8 @@ TEST_F(ForkTest, PrivateMapping) {
EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
}
+// CPUID is x86 specific.
+#ifdef __x86_64__
// Test that cpuid works after a fork.
TEST_F(ForkTest, Cpuid) {
pid_t child = Fork();
@@ -227,6 +229,7 @@ TEST_F(ForkTest, Cpuid) {
}
EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
}
+#endif
TEST_F(ForkTest, Mmap) {
pid_t child = Fork();
@@ -268,7 +271,7 @@ TEST_F(ForkTest, Alarm) {
EXPECT_EQ(0, alarmed);
}
-// Child cannot affect parent private memory.
+// Child cannot affect parent private memory. Regression test for b/24137240.
TEST_F(ForkTest, PrivateMemory) {
std::atomic<uint32_t> local(0);
@@ -295,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) {
}
// Kernel-accessed buffers should remain coherent across COW.
+//
+// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates
+// differently. Regression test for b/33811887.
TEST_F(ForkTest, COWSegment) {
constexpr int kBufSize = 1024;
char* read_buf = private_;
diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc
index ad2dbacb8..b147d6181 100644
--- a/test/syscalls/linux/getdents.cc
+++ b/test/syscalls/linux/getdents.cc
@@ -228,19 +228,28 @@ class GetdentsTest : public ::testing::Test {
// Multiple template parameters are not allowed, so we must use explicit
// template specialization to set the syscall number.
+
+// SYS_getdents isn't defined on arm64.
+#ifdef __x86_64__
template <>
int GetdentsTest<struct linux_dirent>::SyscallNum() {
return SYS_getdents;
}
+#endif
template <>
int GetdentsTest<struct linux_dirent64>::SyscallNum() {
return SYS_getdents64;
}
-// Test both legacy getdents and getdents64.
+#ifdef __x86_64__
+// Test both legacy getdents and getdents64 on x86_64.
typedef ::testing::Types<struct linux_dirent, struct linux_dirent64>
GetdentsTypes;
+#elif __aarch64__
+// Test only getdents64 on arm64.
+typedef ::testing::Types<struct linux_dirent64> GetdentsTypes;
+#endif
TYPED_TEST_SUITE(GetdentsTest, GetdentsTypes);
// N.B. TYPED_TESTs require explicitly using this-> to access members of
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index fdef646eb..0e13ad190 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -1055,9 +1055,9 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
const TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
- const FileDescriptor root_fd =
+ FileDescriptor root_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(root.path(), O_RDONLY));
- const FileDescriptor file1_fd =
+ FileDescriptor file1_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
@@ -1091,6 +1091,11 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
ASSERT_THAT(fchmodat(root_fd.get(), file1_basename.c_str(), S_IWGRP, 0),
SyscallSucceeds());
verify_chmod_events();
+
+ // Make sure the chmod'ed file descriptors are destroyed before DisableSave
+ // is destructed.
+ root_fd.reset();
+ file1_fd.reset();
}
TEST(Inotify, TruncateGeneratesModifyEvent) {
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index 6b472eb2f..bba022a41 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -79,6 +79,33 @@ SocketPairKind DualStackTCPAcceptBindSocketPair(int type) {
/* dual_stack = */ true)};
}
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket");
+ return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ true)};
+}
+
SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) {
std::string description =
absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket");
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 0f58e0f77..39fd6709d 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -50,6 +50,21 @@ SocketPairKind IPv4TCPAcceptBindSocketPair(int type);
// given type bound to the IPv4 loopback.
SocketPairKind DualStackTCPAcceptBindSocketPair(int type);
+// IPv6TCPAcceptBindPersistentListenerSocketPair is like
+// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// IPv4TCPAcceptBindPersistentListenerSocketPair is like
+// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// DualStackTCPAcceptBindPersistentListenerSocketPair is like
+// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket
+// to create all socket pairs.
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type);
+
// IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents
// SocketPairs created with bind() and connect() syscalls with AF_INET6 and the
// given type bound to the IPv6 loopback.
@@ -69,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type);
// SocketPairs created with AF_INET and the given type.
SocketPairKind IPv4UDPUnboundSocketPair(int type);
-// IPv4UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type.
+// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_DGRAM, and the given type.
SocketKind IPv4UDPUnboundSocket(int type);
-// IPv6UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type.
+// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_DGRAM, and the given type.
SocketKind IPv6UDPUnboundSocket(int type);
-// IPv4TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type.
+// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_STREAM and the given type.
SocketKind IPv4TCPUnboundSocket(int type);
-// IPv6TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type.
+// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_STREAM and the given type.
SocketKind IPv6TCPUnboundSocket(int type);
// IfAddrHelper is a helper class that determines the local interfaces present
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index b77e4cbd1..8b48f0804 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -349,6 +349,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index 1c4d9f1c7..11fb1b457 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -1418,7 +1418,7 @@ TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) {
//
// On most platforms this is trivial, but when the file is mapped via the sentry
// page cache (which does not yet support writing to shared mappings), a bug
-// caused reads to fail unnecessarily on such mappings.
+// caused reads to fail unnecessarily on such mappings. See b/28913513.
TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
uintptr_t addr;
size_t len = strlen(kFileContents);
@@ -1435,7 +1435,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
// Tests that EFAULT is returned when invoking a syscall that requires the OS to
// read past end of file (resulting in a fault in sentry context in the gVisor
-// case).
+// case). See b/28913513.
TEST_F(MMapFileTest, InternalSigBus) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
@@ -1578,7 +1578,7 @@ TEST_F(MMapFileTest, Bug38498194) {
}
// Tests that reading from a file to a memory mapping of the same file does not
-// deadlock.
+// deadlock. See b/34813270.
TEST_F(MMapFileTest, SelfRead) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
@@ -1590,7 +1590,7 @@ TEST_F(MMapFileTest, SelfRead) {
}
// Tests that writing to a file from a memory mapping of the same file does not
-// deadlock.
+// deadlock. Regression test for b/34813270.
TEST_F(MMapFileTest, SelfWrite) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
new file mode 100644
index 000000000..6ea48c263
--- /dev/null
+++ b/test/syscalls/linux/network_namespace.cc
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <net/if.h>
+#include <sched.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using TestFunc = std::function<PosixError()>;
+using RunFunc = std::function<PosixError(TestFunc)>;
+
+struct NamespaceStrategy {
+ RunFunc run;
+
+ static NamespaceStrategy Of(RunFunc run) {
+ NamespaceStrategy s;
+ s.run = run;
+ return s;
+ }
+};
+
+PosixError RunWithUnshare(TestFunc fn) {
+ PosixError err = PosixError(-1, "function did not return a value");
+ ScopedThread t([&] {
+ if (unshare(CLONE_NEWNET) != 0) {
+ err = PosixError(errno);
+ return;
+ }
+ err = fn();
+ });
+ t.Join();
+ return err;
+}
+
+PosixError RunWithClone(TestFunc fn) {
+ struct Args {
+ absl::Notification n;
+ TestFunc fn;
+ PosixError err;
+ };
+ Args args;
+ args.fn = fn;
+ args.err = PosixError(-1, "function did not return a value");
+
+ ASSIGN_OR_RETURN_ERRNO(
+ Mapping child_stack,
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ pid_t child = clone(
+ +[](void *arg) {
+ Args *args = reinterpret_cast<Args *>(arg);
+ args->err = args->fn();
+ args->n.Notify();
+ syscall(SYS_exit, 0); // Exit manually. No return address on stack.
+ return 0;
+ },
+ reinterpret_cast<void *>(child_stack.addr() + kPageSize),
+ CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args);
+ if (child < 0) {
+ return PosixError(errno, "clone() failed");
+ }
+ args.n.WaitForNotification();
+ return args.err;
+}
+
+class NetworkNamespaceTest
+ : public ::testing::TestWithParam<NamespaceStrategy> {};
+
+TEST_P(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ EXPECT_NO_ERRNO(GetParam().run([]() {
+ // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
+ // Check loopback device exists.
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ if (sock < 0) {
+ return PosixError(errno, "socket() failed");
+ }
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) {
+ return PosixError(errno, "ioctl() failed, lo cannot be found");
+ }
+ return NoError();
+ }));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllNetworkNamespaceTest, NetworkNamespaceTest,
+ ::testing::Values(NamespaceStrategy::Of(RunWithUnshare),
+ NamespaceStrategy::Of(RunWithClone)));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index 431733dbe..902d0a0dc 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -132,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
}
// A file originally created RW, but opened RO can later be opened RW.
+// Regression test for b/65385065.
TEST(CreateTest, OpenCreateROThenRW) {
TempPath file(NewTempAbsPath());
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 92ae55eec..bc22de788 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arpa/inet.h>
+#include <ifaddrs.h>
#include <linux/capability.h>
#include <linux/if_arp.h>
#include <linux/if_packet.h>
@@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() {
return ifr.ifr_ifindex;
}
-// Receive via a packet socket.
-TEST_P(CookedPacketTest, Receive) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
// Wait for the socket to become readable.
struct pollfd pfd = {};
- pfd.fd = socket_;
+ pfd.fd = sock;
pfd.events = POLLIN;
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
@@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) {
char buf[64];
struct sockaddr_ll src = {};
socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
+
// sockaddr_ll ends with an 8 byte physical address field, but ethernet
// addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
// here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
@@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) {
// TODO(b/129292371): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+ EXPECT_EQ(src.sll_ifindex, ifindex);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
@@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) {
EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
}
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
// TODO(b/129292371): Remove once we support packet socket writing.
@@ -313,6 +322,101 @@ TEST_P(CookedPacketTest, Send) {
EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Double Bind socket.
+TEST_P(CookedPacketTest, DoubleBind) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Binding socket again should fail.
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched
+ // to EADDRINUSE.
+ AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL)));
+}
+
+// Bind and verify we do not receive data on interface which is not bound
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Bind to packet socket requires only family, protocol and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index ac9b21b24..d8e19e910 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -25,6 +25,7 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -144,11 +145,10 @@ TEST_P(PipeTest, Flags) {
if (IsNamedPipe()) {
// May be stubbed to zero; define locally.
- constexpr int kLargefile = 0100000;
EXPECT_THAT(fcntl(rfd_.get(), F_GETFL),
- SyscallSucceedsWithValue(kLargefile | O_RDONLY));
+ SyscallSucceedsWithValue(kOLargeFile | O_RDONLY));
EXPECT_THAT(fcntl(wfd_.get(), F_GETFL),
- SyscallSucceedsWithValue(kLargefile | O_WRONLY));
+ SyscallSucceedsWithValue(kOLargeFile | O_WRONLY));
} else {
EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_RDONLY));
EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_WRONLY));
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index d07571a5f..04c5161f5 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -226,5 +226,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 30f0d75b3..c4e9cf528 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -264,5 +264,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc
index f7ea44054..5b0743fe9 100644
--- a/test/syscalls/linux/preadv.cc
+++ b/test/syscalls/linux/preadv.cc
@@ -37,6 +37,7 @@ namespace testing {
namespace {
+// Stress copy-on-write. Attempts to reproduce b/38430174.
TEST(PreadvTest, MMConcurrencyStress) {
// Fill a one-page file with zeroes (the contents don't really matter).
const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
index cd936ea90..4a9acd7ae 100644
--- a/test/syscalls/linux/preadv2.cc
+++ b/test/syscalls/linux/preadv2.cc
@@ -35,6 +35,8 @@ namespace {
#ifndef SYS_preadv2
#if defined(__x86_64__)
#define SYS_preadv2 327
+#elif defined(__aarch64__)
+#define SYS_preadv2 286
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index bf9bb45d3..f91187e75 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -100,18 +100,6 @@ namespace {
#define SUID_DUMP_ROOT 2
#endif /* SUID_DUMP_ROOT */
-// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
-// because "it isn't needed", even though Linux can return it via F_GETFL.
-#if defined(__x86_64__) || defined(__i386__)
-constexpr int kOLargeFile = 00100000;
-#elif __aarch64__
-// The value originate from the Linux
-// kernel's arch/arm64/include/uapi/asm/fcntl.h.
-constexpr int kOLargeFile = 00400000;
-#else
-#error "Unknown architecture"
-#endif
-
#if defined(__x86_64__) || defined(__i386__)
// This list of "required" fields is taken from reading the file
// arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally
@@ -1364,13 +1352,19 @@ TEST(ProcPidSymlink, SubprocessZombied) {
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
// EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17 & gVisor: Syscall succeeds and returns 1.
+ //
// EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
}
@@ -1443,8 +1437,12 @@ TEST(ProcPidFile, SubprocessRunning) {
TEST(ProcPidFile, SubprocessZombie) {
char buf[1];
- // 4.17: Succeeds and returns 1
- // gVisor: Succeeds and returns 0
+ // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent
+ // behavior on different kernels.
+ //
+ // ~4.3: Succeds and returns 0.
+ // 4.17: Succeeds and returns 1.
+ // gVisor: Succeeds and returns 0.
EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds());
EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)),
@@ -1470,7 +1468,10 @@ TEST(ProcPidFile, SubprocessZombie) {
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
+ //
+ // ~4.3: Fails and returns EACCES.
// gVisor & 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
}
@@ -1479,9 +1480,12 @@ TEST(ProcPidFile, SubprocessZombie) {
TEST(ProcPidFile, SubprocessExited) {
char buf[1];
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels.
+ //
+ // ~4.3: Fails and returns ESRCH.
// gVisor: Fails with ESRCH.
// 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)),
// SyscallFailsWithErrno(ESRCH));
@@ -1653,7 +1657,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
TaskFiles(initial, {child1.Tid()})));
- // Stat child1's task file.
+ // Stat child1's task file. Regression test for b/32097707.
struct stat statbuf;
const std::string child1_task_file =
absl::StrCat("/proc/self/task/", child1.Tid());
@@ -1681,7 +1685,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
"/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()})));
- // Stat child1's task file again. This time it should fail.
+ // Stat child1's task file again. This time it should fail. See b/32097707.
EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf),
SyscallFailsWithErrno(ENOENT));
@@ -1836,7 +1840,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) {
}
// Check that link for proc fd entries point the target node, not the
-// symlink itself.
+// symlink itself. Regression test for b/31155070.
TEST(ProcTaskFd, FstatatFollowsSymlink) {
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const FileDescriptor fd =
@@ -1895,6 +1899,20 @@ TEST(ProcMounts, IsSymlink) {
EXPECT_EQ(link, "self/mounts");
}
+TEST(ProcSelfMountinfo, RequiredFieldsArePresent) {
+ auto mountinfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo"));
+ EXPECT_THAT(
+ mountinfo,
+ AllOf(
+ // Root mount.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"),
+ // Proc mount - always rw.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)")));
+}
+
// Check that /proc/self/mounts looks something like a real mounts file.
TEST(ProcSelfMounts, RequiredFieldsArePresent) {
auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts"));
@@ -2006,7 +2024,7 @@ TEST(Proc, GetdentsEnoent) {
},
nullptr, nullptr));
char buf[1024];
- ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)),
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
SyscallFailsWithErrno(ENOENT));
}
@@ -2058,5 +2076,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index 4dd5cf27b..bfe3e2603 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -1208,5 +1208,5 @@ int main(int argc, char** argv) {
gvisor::testing::RunExecveChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
index 1dbc0d6df..3fe5a600f 100644
--- a/test/syscalls/linux/pwritev2.cc
+++ b/test/syscalls/linux/pwritev2.cc
@@ -34,6 +34,8 @@ namespace {
#ifndef SYS_pwritev2
#if defined(__x86_64__)
#define SYS_pwritev2 328
+#elif defined(__aarch64__)
+#define SYS_pwritev2 287
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
index 4069cbc7e..baaf9f757 100644
--- a/test/syscalls/linux/readv.cc
+++ b/test/syscalls/linux/readv.cc
@@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) {
// This test depends on the maximum extent of a single readv() syscall, so
// we can't tolerate interruption from saving.
TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
- // Ensure that we won't be interrupted by ITIMER_PROF.
+ // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly
+ // important in environments where automated profiling tools may start
+ // ITIMER_PROF automatically.
struct itimerval itv = {};
auto const cleanup_itimer =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv));
diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc
index 106c045e3..4bfb1ff56 100644
--- a/test/syscalls/linux/rseq.cc
+++ b/test/syscalls/linux/rseq.cc
@@ -36,7 +36,7 @@ namespace {
// We must be very careful about how these tests are written. Each thread may
// only have one struct rseq registration, which may be done automatically at
// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does
-// not do so).
+// not do so, but other libraries do).
//
// Testing of rseq is thus done primarily in a child process with no
// registration. This means exec'ing a nostdlib binary, as rseq registration can
diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h
index e3ff0579a..ca1d67691 100644
--- a/test/syscalls/linux/rseq/uapi.h
+++ b/test/syscalls/linux/rseq/uapi.h
@@ -15,14 +15,9 @@
#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
-// User-kernel ABI for restartable sequences.
+#include <stdint.h>
-// Standard types.
-//
-// N.B. This header will be included in targets that do have the standard
-// library, so we can't shadow the standard type names.
-using __u32 = __UINT32_TYPE__;
-using __u64 = __UINT64_TYPE__;
+// User-kernel ABI for restartable sequences.
#ifdef __x86_64__
// Syscall numbers.
@@ -32,20 +27,20 @@ constexpr int kRseqSyscall = 334;
#endif // __x86_64__
struct rseq_cs {
- __u32 version;
- __u32 flags;
- __u64 start_ip;
- __u64 post_commit_offset;
- __u64 abort_ip;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t version;
+ uint32_t flags;
+ uint64_t start_ip;
+ uint64_t post_commit_offset;
+ uint64_t abort_ip;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
// N.B. alignment is enforced by the kernel.
struct rseq {
- __u32 cpu_id_start;
- __u32 cpu_id;
+ uint32_t cpu_id_start;
+ uint32_t cpu_id;
struct rseq_cs* rseq_cs;
- __u32 flags;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t flags;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
constexpr int kRseqFlagUnregister = 1 << 0;
diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc
index 81d193ffd..ed27e2566 100644
--- a/test/syscalls/linux/rtsignal.cc
+++ b/test/syscalls/linux/rtsignal.cc
@@ -167,6 +167,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
index 294ee6808..cf6499f8b 100644
--- a/test/syscalls/linux/seccomp.cc
+++ b/test/syscalls/linux/seccomp.cc
@@ -49,7 +49,12 @@ namespace testing {
namespace {
// A syscall not implemented by Linux that we don't expect to be called.
+#ifdef __x86_64__
constexpr uint32_t kFilteredSyscall = SYS_vserver;
+#elif __aarch64__
+// Use the last of arch_specific_syscalls which are not implemented on arm64.
+constexpr uint32_t kFilteredSyscall = SYS_arch_specific_syscall + 15;
+#endif
// Applies a seccomp-bpf filter that returns `filtered_result` for
// `sysno` and allows all other syscalls. Async-signal-safe.
@@ -406,5 +411,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc
index 424e2a67f..be2364fb8 100644
--- a/test/syscalls/linux/select.cc
+++ b/test/syscalls/linux/select.cc
@@ -146,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) {
// This test illustrates Linux's behavior of 'select' calls passing after
// setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on
-// this behavior.
+// this behavior. See b/122318458.
TEST_F(SelectTest, SetrlimitCallNOFILE) {
fd_set read_set;
FD_ZERO(&read_set);
diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc
index 7ba752599..c7fdbb924 100644
--- a/test/syscalls/linux/shm.cc
+++ b/test/syscalls/linux/shm.cc
@@ -473,7 +473,7 @@ TEST(ShmTest, PartialUnmap) {
}
// Check that sentry does not panic when asked for a zero-length private shm
-// segment.
+// segment. Regression test for b/110694797.
TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) {
EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _));
}
diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc
index 4deb1ae95..6227774a4 100644
--- a/test/syscalls/linux/sigiret.cc
+++ b/test/syscalls/linux/sigiret.cc
@@ -132,6 +132,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
index 95be4b66c..389e5fca2 100644
--- a/test/syscalls/linux/signalfd.cc
+++ b/test/syscalls/linux/signalfd.cc
@@ -369,5 +369,5 @@ int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc
index 654c6a47f..a603fc1d1 100644
--- a/test/syscalls/linux/sigprocmask.cc
+++ b/test/syscalls/linux/sigprocmask.cc
@@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) {
}
// Check that sigprocmask correctly handles aliasing of the set and oldset
-// pointers.
+// pointers. Regression test for b/30502311.
TEST_F(SigProcMaskTest, AliasedSets) {
sigset_t mask;
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 7db57d968..b2fcedd62 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -147,5 +147,5 @@ int main(int argc, char** argv) {
return 1;
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc
index 1e5bf5942..4f8afff15 100644
--- a/test/syscalls/linux/sigtimedwait.cc
+++ b/test/syscalls/linux/sigtimedwait.cc
@@ -319,6 +319,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc
index 715d87b76..00999f192 100644
--- a/test/syscalls/linux/socket_abstract.cc
+++ b/test/syscalls/linux/socket_abstract.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P(
AbstractUnixSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc
index 74e262959..287359363 100644
--- a/test/syscalls/linux/socket_filesystem.cc
+++ b/test/syscalls/linux/socket_filesystem.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P(
FilesystemUnixSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index e8f24a59e..f7d6139f1 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 89);
+ EXPECT_EQ(actual_tv.tv_usec, 42000);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 123000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 123000);
+}
+
TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- struct timeval tv {
- .tv_sec = 0, .tv_usec = 35
- };
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
EXPECT_THAT(
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct timeval_with_extra {
@@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
int64_t extra_data;
} ABSL_ATTRIBUTE_PACKED;
- timeval_with_extra tv_extra;
- tv_extra.tv.tv_sec = 0;
- tv_extra.tv.tv_usec = 25;
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
&tv_extra, sizeof(tv_extra)),
SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
}
TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc
new file mode 100644
index 000000000..6a232238d
--- /dev/null
+++ b/test/syscalls/linux/socket_generic_stress.cc
@@ -0,0 +1,83 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected sockets.
+using ConnectStressTest = SocketPairTest;
+
+TEST_P(ConnectStressTest, Reset65kTimes) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Send some data to ensure that the connection gets reset and the port gets
+ // released immediately. This avoids either end entering TIME-WAIT.
+ char sent_data[100] = {};
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, ConnectStressTest,
+ ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindSocketPair(0))));
+
+// Test fixture for tests that apply to pairs of connected sockets created with
+// a persistent listener (if applicable).
+using PersistentListenerConnectStressTest = SocketPairTest;
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimes) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, PersistentListenerConnectStressTest,
+ ::testing::Values(
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv6TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv4TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindPersistentListenerSocketPair(0))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 2f9821555..b24618a88 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -325,6 +325,12 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
+ constexpr int kAcceptCount = 32;
+ constexpr int kBacklog = kAcceptCount * 2;
+ constexpr int kFDs = 128;
+ constexpr int kThreadCount = 4;
+ constexpr int kFDsPerThread = kFDs / kThreadCount;
+
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -332,7 +338,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
@@ -345,9 +351,6 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
DisableSave ds; // Too many system calls.
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- constexpr int kFDs = 2048;
- constexpr int kThreadCount = 4;
- constexpr int kFDsPerThread = kFDs / kThreadCount;
FileDescriptor clients[kFDs];
std::unique_ptr<ScopedThread> threads[kThreadCount];
for (int i = 0; i < kFDs; i++) {
@@ -371,7 +374,7 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
for (int i = 0; i < kThreadCount; i++) {
threads[i]->Join();
}
- for (int i = 0; i < 32; i++) {
+ for (int i = 0; i < kAcceptCount; i++) {
auto accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
}
@@ -828,6 +831,164 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
EXPECT_EQ(get, kUserTimeout);
}
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now write some data to the socket.
+ int data = 0;
+ ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ // This should now cause the connection to complete and be delivered to the
+ // accept socket.
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Verify that the accepted socket returns the data written.
+ int get = -1;
+ ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0),
+ SyscallSucceedsWithValue(sizeof(get)));
+
+ EXPECT_EQ(get, data);
+}
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R once issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT
+ // timeout is hit.
+ absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1));
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout
+ // is hit a SYN-ACK should be retransmitted by the listener as a last ditch
+ // attempt to complete the connection with or without data.
+ absl::SleepFor(absl::Seconds(2));
+
+ // Verify that we have a connection that can be accepted even though no
+ // data was written.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetLoopbackTest,
::testing::Values(
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 57ce8e169..27779e47c 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -24,6 +24,7 @@
#include <sys/un.h>
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -875,5 +876,37 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) {
EXPECT_EQ(get, kAbove);
}
+TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) {
+ DisableSave ds; // Too many syscalls.
+ constexpr int kThreadCount = 1000;
+ std::unique_ptr<ScopedThread> instances[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i] = absl::make_unique<ScopedThread>([&]() {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ScopedThread t([&]() {
+ // Close one end to trigger sending of a FIN.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ // Wait up to 20 seconds for the data.
+ constexpr int kPollTimeoutMs = 20000;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ });
+
+ // Send some data then close.
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->first_fd(), kStr, 3),
+ SyscallSucceedsWithValue(3));
+ absl::SleepFor(absl::Milliseconds(10));
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ t.Join();
+ });
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i]->Join();
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
index d11f7cc23..4e79d21f4 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -39,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
AllTCPSockets, TCPSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
index fcd20102f..f996b93d2 100644
--- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -39,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingTCPSockets, BlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
index 63a05b799..ffa377210 100644
--- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -38,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingTCPSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 53290bed7..1c533fdf2 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -14,6 +14,7 @@
#include "test/syscalls/linux/socket_ip_udp_generic.h"
+#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
EXPECT_EQ(get, kSockOptOn);
}
-// Ensure that Receiving TOS is off by default.
-TEST_P(UDPSocketPairTest, RecvTosDefault) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
-
- int get = -1;
- socklen_t get_len = sizeof(get);
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOff);
-}
-
-// Test that setting and getting IP_RECVTOS works as expected.
-TEST_P(UDPSocketPairTest, SetRecvTos) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
-
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
- &kSockOptOff, sizeof(kSockOptOff)),
- SyscallSucceeds());
-
- int get = -1;
- socklen_t get_len = sizeof(get);
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOff);
-
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
- &kSockOptOn, sizeof(kSockOptOn)),
- SyscallSucceeds());
-
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOn);
-}
-
TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -357,5 +318,141 @@ TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) {
EXPECT_EQ(get, kSockOptOn);
}
+// Test getsockopt for a socket which is not set with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, IPPKTINFODefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ // Check getsockopt before IP_PKTINFO is set.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
+
+// Holds TOS or TClass information for IPv4 or IPv6 respectively.
+struct RecvTosOption {
+ int level;
+ int option;
+};
+
+RecvTosOption GetRecvTosOption(int domain) {
+ TEST_CHECK(domain == AF_INET || domain == AF_INET6);
+ RecvTosOption opt;
+ switch (domain) {
+ case AF_INET:
+ opt.level = IPPROTO_IP;
+ opt.option = IP_RECVTOS;
+ break;
+ case AF_INET6:
+ opt.level = IPPROTO_IPV6;
+ opt.option = IPV6_RECVTCLASS;
+ break;
+ }
+ return opt;
+}
+
+// Ensure that Receiving TOS or TCLASS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as
+// expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this
+// mirrors behavior in linux.
+TEST_P(UDPSocketPairTest, TOSRecvMismatch) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(AF_INET);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+}
+
+// Test that an IPv4 socket does not support the IPv6 TClass option.
+TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
+ // This should only test AF_INET sockets for the mismatch behavior.
+ SKIP_IF(GetParam().domain != AF_INET);
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS,
+ &get, &get_len),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc
index 1df74a348..c7fa44884 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -44,5 +45,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUDPSockets, UDPSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
index 1e259efa7..d6925a8df 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingUDPSockets, BlockingNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
index 74cbd326d..d675eddc6 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUDPSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
index 3ac790873..797c4174e 100644
--- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
@@ -22,6 +22,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketKind> GetSockets() {
return ApplyVec<SocketKind>(
@@ -32,5 +33,7 @@ std::vector<SocketKind> GetSockets() {
INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets,
IPv4TCPUnboundExternalNetworkingSocketTest,
::testing::ValuesIn(GetSockets()));
+
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index aa6fb4e3f..bc4b07a62 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -15,6 +15,7 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
#include <arpa/inet.h>
+#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
@@ -30,27 +31,6 @@
namespace gvisor {
namespace testing {
-constexpr char kMulticastAddress[] = "224.0.2.1";
-constexpr char kBroadcastAddress[] = "255.255.255.255";
-
-TestAddress V4Multicast() {
- TestAddress t("V4Multicast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kMulticastAddress);
- return t;
-}
-
-TestAddress V4Broadcast() {
- TestAddress t("V4Broadcast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kBroadcastAddress);
- return t;
-}
-
// Check that packets are not received without a group membership. Default send
// interface configured by bind.
TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
@@ -2149,5 +2129,88 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(kMessageSize));
}
+// Test that socket will receive packet info control message.
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF((IsRunningWithHostinet()));
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto sender_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ auto receiver_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port;
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {};
+ size_t cmsg_data_len = sizeof(in_pktinfo);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Get loopback index.
+ ifreq ifr = {};
+ absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ ASSERT_NE(ifr.ifr_ifindex, 0);
+
+ // Check the data
+ in_pktinfo received_pktinfo = {};
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 98ae414f3..40e673625 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -41,26 +41,6 @@ TestAddress V4EmptyAddress() {
return t;
}
-constexpr char kMulticastAddress[] = "224.0.2.1";
-
-TestAddress V4Multicast() {
- TestAddress t("V4Multicast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kMulticastAddress);
- return t;
-}
-
-TestAddress V4Broadcast() {
- TestAddress t("V4Broadcast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- htonl(INADDR_BROADCAST);
- return t;
-}
-
void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
got_if_infos_ = false;
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
index 8f47952b0..f6e64c157 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
@@ -22,6 +22,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketKind> GetSockets() {
return ApplyVec<SocketKind>(
@@ -32,5 +33,7 @@ std::vector<SocketKind> GetSockets() {
INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets,
IPv4UDPUnboundExternalNetworkingSocketTest,
::testing::ValuesIn(GetSockets()));
+
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index 1e28e658d..e5aed1eec 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -14,6 +14,7 @@
#include <arpa/inet.h>
#include <ifaddrs.h>
+#include <linux/if.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
@@ -25,8 +26,10 @@
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
+#include "absl/types/optional.h"
#include "test/syscalls/linux/socket_netlink_util.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
@@ -38,6 +41,8 @@ namespace testing {
namespace {
+constexpr uint32_t kSeq = 12345;
+
using ::testing::AnyOf;
using ::testing::Eq;
@@ -113,58 +118,224 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
// TODO(mpratt): Check ifinfomsg contents and following attrs.
}
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
TEST(NetlinkRouteTest, GetLinkDump) {
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+ // Loopback is common among all tests, check that it's found.
+ bool loopbackFound = false;
+ ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ CheckGetLinkResponse(hdr, kSeq, port);
+ if (hdr->nlmsg_type != RTM_NEWLINK) {
+ return;
+ }
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ std::cout << "Found interface idx=" << msg->ifi_index
+ << ", type=" << std::hex << msg->ifi_type;
+ if (msg->ifi_type == ARPHRD_LOOPBACK) {
+ loopbackFound = true;
+ EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
+ }
+ }));
+ EXPECT_TRUE(loopbackFound);
+}
+
+struct Link {
+ int index;
+ std::string name;
+};
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ absl::optional<Link> link;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ if (msg->ifi_type == ARPHRD_LOOPBACK) {
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ link = Link();
+ link->index = msg->ifi_index;
+ link->name = std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }
+ }));
+ return link;
+}
+
+// CheckLinkMsg checks a netlink message against an expected link.
+void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
+ ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK));
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ EXPECT_EQ(msg->ifi_index, link.index);
+
+ const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message.";
+ if (rta != nullptr) {
+ std::string name(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ EXPECT_EQ(name, link.name);
+ }
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndex) {
+ absl::optional<Link> loopback_link =
+ ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
+ ASSERT_TRUE(loopback_link.has_value());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = loopback_link->index;
- // Loopback is common among all tests, check that it's found.
- bool loopbackFound = false;
+ bool found = false;
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckGetLinkResponse(hdr, kSeq, port);
- if (hdr->nlmsg_type != RTM_NEWLINK) {
- return;
- }
- ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
- const struct ifinfomsg* msg =
- reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
- std::cout << "Found interface idx=" << msg->ifi_index
- << ", type=" << std::hex << msg->ifi_type;
- if (msg->ifi_type == ARPHRD_LOOPBACK) {
- loopbackFound = true;
- EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
- }
+ CheckLinkMsg(hdr, *loopback_link);
+ found = true;
},
false));
- EXPECT_TRUE(loopbackFound);
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
}
-TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
+TEST(NetlinkRouteTest, GetLinkByName) {
+ absl::optional<Link> loopback_link =
+ ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
+ ASSERT_TRUE(loopback_link.has_value());
+
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
};
- constexpr uint32_t kSeq = 12345;
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1);
+ strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ bool found = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ CheckLinkMsg(hdr, *loopback_link);
+ found = true;
+ },
+ false));
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndexNotFound) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = 1234590;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, GetLinkByNameNotFound) {
+ const std::string name = "nodevice?!";
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(name.size() + 1);
+ strncpy(req.ifname, name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
@@ -175,18 +346,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
- ASSERT_NO_ERRNO(NetlinkRequestResponse(
- fd, &req, sizeof(req),
- [&](const struct nlmsghdr* hdr) {
- EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR));
- EXPECT_EQ(hdr->nlmsg_seq, kSeq);
- EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
-
- const struct nlmsgerr* msg =
- reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
- EXPECT_EQ(msg->error, -EOPNOTSUPP);
- },
- true));
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(EOPNOTSUPP, ::testing::_));
}
TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
@@ -198,8 +359,6 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
@@ -238,8 +397,6 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
@@ -282,8 +439,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) {
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
// This control message is ignored. We still receive a response for the
@@ -317,8 +472,6 @@ TEST(NetlinkRouteTest, GetAddrDump) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -367,6 +520,57 @@ TEST(NetlinkRouteTest, LookupAll) {
ASSERT_GT(count, 0);
}
+TEST(NetlinkRouteTest, AddAddr) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ absl::optional<Link> loopback_link =
+ ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink());
+ ASSERT_TRUE(loopback_link.has_value());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifa;
+ struct rtattr rtattr;
+ struct in_addr addr;
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifa.ifa_family = AF_INET;
+ req.ifa.ifa_prefixlen = 24;
+ req.ifa.ifa_flags = 0;
+ req.ifa.ifa_scope = 0;
+ req.ifa.ifa_index = loopback_link->index;
+ req.rtattr.rta_type = IFA_LOCAL;
+ req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
+ inet_pton(AF_INET, "10.0.0.1", &req.addr);
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ // Create should succeed, as no such address in kernel.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // Replace an existing address should succeed.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // Create exclusive should fail, as we created the address above.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_THAT(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len),
+ PosixErrorIs(EEXIST, ::testing::_));
+}
+
// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request.
TEST(NetlinkRouteTest, GetRouteDump) {
FileDescriptor fd =
@@ -378,8 +582,6 @@ TEST(NetlinkRouteTest, GetRouteDump) {
struct rtmsg rtm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETROUTE;
@@ -538,8 +740,6 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -615,8 +815,6 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -695,8 +893,6 @@ TEST(NetlinkRouteTest, NoPasscredNoCreds) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -743,8 +939,6 @@ TEST(NetlinkRouteTest, PasscredCreds) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
new file mode 100644
index 000000000..53eb3b6b2
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+} // namespace
+
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
+PosixErrorOr<std::vector<Link>> DumpLinks() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ std::vector<Link> links;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ links.emplace_back();
+ links.back().index = msg->ifi_index;
+ links.back().type = msg->ifi_type;
+ links.back().name =
+ std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }));
+ return links;
+}
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.type == ARPHRD_LOOPBACK) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char pad[NLMSG_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+ req.ifinfo.ifi_flags = flags;
+ req.ifinfo.ifi_change = change;
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFLA_ADDRESS;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
new file mode 100644
index 000000000..2c018e487
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+
+struct Link {
+ int index;
+ int16_t type;
+ std::string name;
+};
+
+PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+PosixErrorOr<std::vector<Link>> DumpLinks();
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink();
+
+// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkChangeFlags changes interface flags. E.g. IFF_UP.
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
+
+// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface.
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc
index cd2212a1a..952eecfe8 100644
--- a/test/syscalls/linux/socket_netlink_util.cc
+++ b/test/syscalls/linux/socket_netlink_util.cc
@@ -16,6 +16,7 @@
#include <linux/if_arp.h>
#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
#include <sys/socket.h>
#include <vector>
@@ -71,9 +72,10 @@ PosixError NetlinkRequestResponse(
iov.iov_base = buf.data();
iov.iov_len = buf.size();
- // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE
- // message.
+ // If NLM_F_MULTI is set, response is a series of messages that ends with a
+ // NLMSG_DONE message.
int type = -1;
+ int flags = 0;
do {
int len;
RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
@@ -89,6 +91,7 @@ PosixError NetlinkRequestResponse(
for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
fn(hdr);
+ flags = hdr->nlmsg_flags;
type = hdr->nlmsg_type;
// Done should include an integer payload for dump_done_errno.
// See net/netlink/af_netlink.c:netlink_dump
@@ -98,11 +101,11 @@ PosixError NetlinkRequestResponse(
EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
}
}
- } while (type != NLMSG_DONE && type != NLMSG_ERROR);
+ } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR);
if (expect_nlmsgerr) {
EXPECT_EQ(type, NLMSG_ERROR);
- } else {
+ } else if (flags & NLM_F_MULTI) {
EXPECT_EQ(type, NLMSG_DONE);
}
return NoError();
@@ -146,5 +149,39 @@ PosixError NetlinkRequestResponseSingle(
return NoError();
}
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len) {
+ // Dummy negative number for no error message received.
+ // We won't get a negative error number so there will be no confusion.
+ int err = -42;
+ RETURN_IF_ERRNO(NetlinkRequestResponse(
+ fd, request, len,
+ [&](const struct nlmsghdr* hdr) {
+ EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type);
+ EXPECT_EQ(hdr->nlmsg_seq, seq);
+ EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
+
+ const struct nlmsgerr* msg =
+ reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
+ err = -msg->error;
+ },
+ true));
+ return PosixError(err);
+}
+
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr) {
+ const int ifi_space = NLMSG_SPACE(sizeof(*msg));
+ int attrlen = hdr->nlmsg_len - ifi_space;
+ const struct rtattr* rta = reinterpret_cast<const struct rtattr*>(
+ reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space));
+ for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
+ if (rta->rta_type == attr) {
+ return rta;
+ }
+ }
+ return nullptr;
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h
index 3678c0599..e13ead406 100644
--- a/test/syscalls/linux/socket_netlink_util.h
+++ b/test/syscalls/linux/socket_netlink_util.h
@@ -19,6 +19,7 @@
// socket.h has to be included before if_arp.h.
#include <linux/if_arp.h>
#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -47,6 +48,14 @@ PosixError NetlinkRequestResponseSingle(
const FileDescriptor& fd, void* request, size_t len,
const std::function<void(const struct nlmsghdr* hdr)>& fn);
+// Send the passed request then expect and return an ack or error.
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len);
+
+// Find rtnetlink attribute in message.
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr);
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index eff7d577e..5d3a39868 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -18,10 +18,13 @@
#include <poll.h>
#include <sys/socket.h>
+#include <memory>
+
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
+#include "absl/types/optional.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
@@ -109,7 +112,10 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain,
MaybeSave(); // Unlinked path.
}
- return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
+ // accepted is before connected to destruct connected before accepted.
+ // Destructors for nonstatic member objects are called in the reverse order
+ // in which they appear in the class declaration.
+ return absl::make_unique<AddrFDSocketPair>(accepted, connected, bind_addr,
extra_addr);
};
}
@@ -311,11 +317,16 @@ PosixErrorOr<T> BindIP(int fd, bool dual_stack) {
}
template <typename T>
-PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
- int bound, int connected, int type, bool dual_stack) {
- ASSIGN_OR_RETURN_ERRNO(T bind_addr, BindIP<T>(bound, dual_stack));
- RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5));
+PosixErrorOr<T> TCPBindAndListen(int fd, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T addr, BindIP<T>(fd, dual_stack));
+ RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5));
+ return addr;
+}
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>>
+CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
+ bool dual_stack, T bind_addr) {
int connect_result = 0;
RETURN_ERROR_IF_SYSCALL_FAIL(
(connect_result = RetryEINTR(connect)(
@@ -358,16 +369,27 @@ PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
absl::SleepFor(absl::Seconds(1));
}
- // Cleanup no longer needed resources.
- RETURN_ERROR_IF_SYSCALL_FAIL(close(bound));
- MaybeSave(); // Successful close.
-
T extra_addr = {};
LocalhostAddr(&extra_addr, dual_stack);
return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
extra_addr);
}
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
+ int bound, int connected, int type, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen<T>(bound, dual_stack));
+
+ auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type,
+ dual_stack, bind_addr);
+
+ // Cleanup no longer needed resources.
+ RETURN_ERROR_IF_SYSCALL_FAIL(close(bound));
+ MaybeSave(); // Successful close.
+
+ return result;
+}
+
Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
int protocol,
bool dual_stack) {
@@ -389,6 +411,63 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
};
}
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack) {
+ // These are lazily initialized below, on the first call to the returned
+ // lambda. These values are private to each returned lambda, but shared across
+ // invocations of a specific lambda.
+ //
+ // The sharing allows pairs created with the same parameters to share a
+ // listener. This prevents future connects from failing if the connecting
+ // socket selects a port which had previously been used by a listening socket
+ // that still has some connections in TIME-WAIT.
+ //
+ // The lazy initialization is to avoid creating sockets during parameter
+ // enumeration. This is important because parameters are enumerated during the
+ // build process where networking may not be available.
+ auto listener = std::make_shared<absl::optional<int>>(absl::optional<int>());
+ auto addr4 = std::make_shared<absl::optional<sockaddr_in>>(
+ absl::optional<sockaddr_in>());
+ auto addr6 = std::make_shared<absl::optional<sockaddr_in6>>(
+ absl::optional<sockaddr_in6>());
+
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ int connected;
+ RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ // Share the listener across invocations.
+ if (!listener->has_value()) {
+ int fd = socket(domain, type, protocol);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrCat("socket(", domain, ", ", type,
+ ", ", protocol, ")"));
+ }
+ listener->emplace(fd);
+ MaybeSave(); // Successful socket creation.
+ }
+
+ // Bind the listener once, but create a new connect/accept pair each
+ // time.
+ if (domain == AF_INET) {
+ if (!addr4->has_value()) {
+ addr4->emplace(
+ TCPBindAndListen<sockaddr_in>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected,
+ type, dual_stack, addr4->value());
+ }
+ if (!addr6->has_value()) {
+ addr6->emplace(
+ TCPBindAndListen<sockaddr_in6>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type,
+ dual_stack, addr6->value());
+ };
+}
+
template <typename T>
PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateUDPBoundSocketPair(
int sock1, int sock2, int type, bool dual_stack) {
@@ -518,8 +597,8 @@ size_t CalculateUnixSockAddrLen(const char* sun_path) {
if (sun_path[0] == 0) {
return sizeof(sockaddr_un);
}
- // Filesystem addresses use the address length plus the 2 byte sun_family and
- // null terminator.
+ // Filesystem addresses use the address length plus the 2 byte sun_family
+ // and null terminator.
return strlen(sun_path) + 3;
}
@@ -726,6 +805,24 @@ TestAddress V4MappedLoopback() {
return t;
}
+TestAddress V4Multicast() {
+ TestAddress t("V4Multicast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ inet_addr(kMulticastAddress);
+ return t;
+}
+
+TestAddress V4Broadcast() {
+ TestAddress t("V4Broadcast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ htonl(INADDR_BROADCAST);
+ return t;
+}
+
TestAddress V6Any() {
TestAddress t("V6Any");
t.addr.ss_family = AF_INET6;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 2dbb8bed3..734b48b96 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -273,6 +273,12 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
int protocol,
bool dual_stack);
+// TCPAcceptBindPersistentListenerSocketPairCreator is like
+// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to
+// create all SocketPairs.
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack);
+
// UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the bind() and connect() syscalls on UDP
// sockets.
@@ -478,10 +484,15 @@ struct TestAddress {
: description(std::move(description)), addr(), addr_len() {}
};
+constexpr char kMulticastAddress[] = "224.0.2.1";
+constexpr char kBroadcastAddress[] = "255.255.255.255";
+
TestAddress V4Any();
+TestAddress V4Broadcast();
TestAddress V4Loopback();
TestAddress V4MappedAny();
TestAddress V4MappedLoopback();
+TestAddress V4Multicast();
TestAddress V6Any();
TestAddress V6Loopback();
diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
index be31ab2a7..8bef76b67 100644
--- a/test/syscalls/linux/socket_unix_abstract_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc
index 6f84221b2..77cb8c6d6 100644
--- a/test/syscalls/linux/socket_unix_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_blocking_local.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -39,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUnixDomainSockets, BlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc
index 9134fcdf7..31d2d5216 100644
--- a/test/syscalls/linux/socket_unix_dgram_local.cc
+++ b/test/syscalls/linux/socket_unix_dgram_local.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
@@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P(
DgramUnixSockets, NonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_domain.cc b/test/syscalls/linux/socket_unix_domain.cc
index fa3efc7f8..f7dff8b4d 100644
--- a/test/syscalls/linux/socket_unix_domain.cc
+++ b/test/syscalls/linux/socket_unix_domain.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, AllSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
index 8ba7af971..6700b4d90 100644
--- a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index 276a94eb8..884319e1d 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -109,7 +109,7 @@ PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size,
}
// A contiguous iov that is heavily fragmented in FileMem can still be sent
-// successfully.
+// successfully. See b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
}
// A contiguous iov that is heavily fragmented in FileMem can still be received
-// into successfully.
+// into successfully. Regression test for b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
index 8855d5001..fddcdf1c5 100644
--- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -36,5 +37,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc
index 411fb4518..85999db04 100644
--- a/test/syscalls/linux/socket_unix_pair.cc
+++ b/test/syscalls/linux/socket_unix_pair.cc
@@ -22,6 +22,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(ApplyVec<SocketPairKind>(
@@ -38,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc
index 3135d325f..281410a9a 100644
--- a/test/syscalls/linux/socket_unix_pair_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_pair_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc
index dff75a532..69a5f150d 100644
--- a/test/syscalls/linux/socket_unix_seqpacket_local.cc
+++ b/test/syscalls/linux/socket_unix_seqpacket_local.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
@@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P(
SeqpacketUnixSockets, UnixNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
index 08e579ba7..8429bd429 100644
--- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -34,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingStreamUnixSockets, BlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc
index 65eef1a81..a7e3449a9 100644
--- a/test/syscalls/linux/socket_unix_stream_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_local.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -42,5 +43,6 @@ INSTANTIATE_TEST_SUITE_P(
StreamUnixSockets, StreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
index 1936aa135..4b763c8e2 100644
--- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
@@ -20,6 +20,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index 85232cb1f..faa1247f6 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -60,6 +60,62 @@ TEST(SpliceTest, TwoRegularFiles) {
SyscallFailsWithErrno(EINVAL));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SpliceTest, NegativeOffset) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file as write only.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("negative", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
+ loff_t out_offset = 0xffffffffffffffffull;
+ constexpr int kSize = 2;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Write offset + size overflows int64.
+//
+// This is a regression test for b/148041624.
+TEST(SpliceTest, WriteOverflow) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(SpliceTest, SamePipe) {
// Create a new pipe.
int fds[2];
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 388d75835..c951ac3b3 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -557,6 +557,8 @@ TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) {
#ifndef SYS_statx
#if defined(__x86_64__)
#define SYS_statx 332
+#elif defined(__aarch64__)
+#define SYS_statx 291
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index b249ff91f..03ee1250d 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -38,7 +38,7 @@ mode_t FilePermission(const std::string& path) {
}
// Test that name collisions are checked on the new link path, not the source
-// path.
+// path. Regression test for b/31782115.
TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) {
const std::string srcname = NewTempAbsPath();
const std::string newname = NewTempAbsPath();
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 33a5ac66c..c4591a3b9 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -244,7 +244,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) {
}
// Test that a non-blocking write with a buffer that is larger than the send
-// buffer size will not actually write the whole thing at once.
+// buffer size will not actually write the whole thing at once. Regression test
+// for b/64438887.
TEST_P(TcpSocketTest, NonblockingLargeWrite) {
// Set the FD to O_NONBLOCK.
int opts;
@@ -1286,6 +1287,68 @@ TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) {
EXPECT_EQ(get, kTCPUserTimeout);
}
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // -ve TCP_DEFER_ACCEPT is same as setting it to zero.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // kTCPDeferAccept is in seconds.
+ // NOTE: linux translates seconds to # of retries and back from
+ // #of retries to seconds. Which means only certain values
+ // translate back exactly. That's why we use 3 here, a value of
+ // 5 will result in us getting back 7 instead of 5 in the
+ // getsockopt.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPDeferAccept);
+}
+
+TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
index c7eead17e..1ccb95733 100644
--- a/test/syscalls/linux/time.cc
+++ b/test/syscalls/linux/time.cc
@@ -62,6 +62,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) {
::testing::KilledBySignal(SIGSEGV), "");
}
+// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2.
int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) {
constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000;
return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>(
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index 3db18d7ac..4b3c44527 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -297,9 +297,13 @@ class IntervalTimer {
PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
const struct sigevent& sev) {
int timerid;
- if (syscall(SYS_timer_create, clockid, &sev, &timerid) < 0) {
+ int ret = syscall(SYS_timer_create, clockid, &sev, &timerid);
+ if (ret < 0) {
return PosixError(errno, "timer_create");
}
+ if (ret > 0) {
+ return PosixError(EINVAL, "timer_create should never return positive");
+ }
MaybeSave();
return IntervalTimer(timerid);
}
@@ -317,6 +321,18 @@ TEST(IntervalTimerTest, IsInitiallyStopped) {
EXPECT_EQ(0, its.it_value.tv_nsec);
}
+// Kernel can create multiple timers without issue.
+//
+// Regression test for gvisor.dev/issue/1738.
+TEST(IntervalTimerTest, MultipleTimers) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ const auto timer2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+}
+
TEST(IntervalTimerTest, SingleShotSilent) {
struct sigevent sev = {};
sev.sigev_notify = SIGEV_NONE;
@@ -642,5 +658,5 @@ int main(int argc, char** argv) {
}
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc
index bae377c69..8d8ebbb24 100644
--- a/test/syscalls/linux/tkill.cc
+++ b/test/syscalls/linux/tkill.cc
@@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) {
TEST_CHECK(info->si_code == SI_TKILL);
}
-// Test with a real signal.
+// Test with a real signal. Regression test for b/24790092.
TEST(TkillTest, ValidTIDAndRealSignal) {
struct sigaction sa;
sa.sa_sigaction = SigHandler;
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
new file mode 100644
index 000000000..f6ac9d7b8
--- /dev/null
+++ b/test/syscalls/linux/tuntap.cc
@@ -0,0 +1,346 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_ether.h>
+#include <linux/if_tun.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr int kIPLen = 4;
+
+constexpr const char kDevNetTun[] = "/dev/net/tun";
+constexpr const char kTapName[] = "tap0";
+
+constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
+constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB};
+
+PosixErrorOr<std::set<std::string>> DumpLinkNames() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ std::set<std::string> names;
+ for (const auto& link : links) {
+ names.emplace(link.name);
+ }
+ return names;
+}
+
+PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.name == name) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+struct pihdr {
+ uint16_t pi_flags;
+ uint16_t pi_protocol;
+} __attribute__((packed));
+
+struct ping_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct iphdr ip;
+ struct icmphdr icmp;
+ char payload[64];
+} __attribute__((packed));
+
+ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ ping_pkt pkt = {};
+
+ pkt.pi.pi_protocol = htons(ETH_P_IP);
+
+ memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest));
+ memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source));
+ pkt.eth.h_proto = htons(ETH_P_IP);
+
+ pkt.ip.ihl = 5;
+ pkt.ip.version = 4;
+ pkt.ip.tos = 0;
+ pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) +
+ sizeof(pkt.payload));
+ pkt.ip.id = 1;
+ pkt.ip.frag_off = 1 << 6; // Do not fragment
+ pkt.ip.ttl = 64;
+ pkt.ip.protocol = IPPROTO_ICMP;
+ inet_pton(AF_INET, dstip, &pkt.ip.daddr);
+ inet_pton(AF_INET, srcip, &pkt.ip.saddr);
+ pkt.ip.check = IPChecksum(pkt.ip);
+
+ pkt.icmp.type = ICMP_ECHO;
+ pkt.icmp.code = 0;
+ pkt.icmp.checksum = 0;
+ pkt.icmp.un.echo.sequence = 1;
+ pkt.icmp.un.echo.id = 1;
+
+ strncpy(pkt.payload, "abcd", sizeof(pkt.payload));
+ pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload));
+
+ return pkt;
+}
+
+struct arp_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct arphdr arp;
+ uint8_t arp_sha[ETH_ALEN];
+ uint8_t arp_spa[kIPLen];
+ uint8_t arp_tha[ETH_ALEN];
+ uint8_t arp_tpa[kIPLen];
+} __attribute__((packed));
+
+std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ std::string buffer;
+ buffer.resize(sizeof(arp_pkt));
+
+ arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]);
+ {
+ pkt->pi.pi_protocol = htons(ETH_P_ARP);
+
+ memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest));
+ memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source));
+ pkt->eth.h_proto = htons(ETH_P_ARP);
+
+ pkt->arp.ar_hrd = htons(ARPHRD_ETHER);
+ pkt->arp.ar_pro = htons(ETH_P_IP);
+ pkt->arp.ar_hln = ETH_ALEN;
+ pkt->arp.ar_pln = kIPLen;
+ pkt->arp.ar_op = htons(ARPOP_REPLY);
+
+ memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha));
+ inet_pton(AF_INET, srcip, pkt->arp_spa);
+ memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha));
+ inet_pton(AF_INET, dstip, pkt->arp_tpa);
+ }
+ return buffer;
+}
+
+} // namespace
+
+class TuntapTest : public ::testing::Test {
+ protected:
+ void TearDown() override {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ // Bring back capability if we had dropped it in test case.
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
+ }
+ }
+};
+
+TEST_F(TuntapTest, CreateInterfaceNoCap) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ strncpy(ifr.ifr_name, kTapName, IFNAMSIZ);
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(TuntapTest, CreateFixedNameInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_expect = ifr_set;
+ // See __tun_chr_ioctl() in net/drivers/tun.c.
+ ifr_expect.ifr_flags |= IFF_NOFILTER;
+
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(kTapName)));
+ EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0));
+}
+
+TEST_F(TuntapTest, CreateInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ // Empty ifr.ifr_name. Let kernel assign.
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ std::string ifname = ifr_get.ifr_name;
+ EXPECT_THAT(ifname, ::testing::StartsWith("tap"));
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(ifname)));
+}
+
+TEST_F(TuntapTest, InvalidReadWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ char buf[128] = {};
+ EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+}
+
+TEST_F(TuntapTest, WriteToDownDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ // Device created should be down by default.
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ char buf[128] = {};
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with canned reply and resends
+// the
+// ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // Interface creation.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ absl::optional<Link> link =
+ ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName));
+ ASSERT_TRUE(link.has_value());
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, "10.0.0.1", &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
+ &addr, sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME: gVisor doesn't support setting MAC address on interfaces yet.
+ EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
+ }
+
+ ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+ std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+
+ // Send ping, this would trigger an ARP request on Linux.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+
+ // Receive loop to process inbound packets.
+ struct inpkt {
+ union {
+ pihdr pi;
+ ping_pkt ping;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ // Process ARP packet.
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ // Respond with canned ARP reply.
+ EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()),
+ SyscallSucceedsWithValue(arp_rep.size()));
+ // First ping request might have been dropped due to mac address not in
+ // ARP cache. Send it again.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+ }
+
+ // Process ping response packet.
+ if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol &&
+ r.ping.ip.protocol == ping_req.ip.protocol &&
+ !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) &&
+ !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) &&
+ r.ping.icmp.type == 0 && r.ping.icmp.code == 0) {
+ // Ends and passes the test.
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index a2f6ef8cc..740c7986d 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -21,6 +21,10 @@
#include <sys/socket.h>
#include <sys/types.h>
+#ifndef SIOCGSTAMP
+#include <linux/sockios.h>
+#endif
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/time/clock.h"
@@ -1349,9 +1353,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
// outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SetAndReceiveTOS) {
- // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack.
- SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() &&
- !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
@@ -1422,7 +1423,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) {
// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SendAndReceiveTOS) {
- // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack.
// TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
@@ -1495,6 +1495,5 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) {
memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
EXPECT_EQ(received_tos, sent_tos);
}
-
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index 0aaba482d..19d05998e 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -191,5 +191,5 @@ int main(int argc, char** argv) {
return gvisor::testing::RunChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index ab21d68c6..8b00ef44c 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -24,6 +24,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -38,6 +39,16 @@ namespace {
class XattrTest : public FileTest {};
+TEST_F(XattrTest, XattrNonexistentFile) {
+ const char* path = "/does/not/exist";
+ EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
+ SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(ENOENT));
+}
+
TEST_F(XattrTest, XattrNullName) {
const char* path = test_file_name_.c_str();
@@ -45,6 +56,7 @@ TEST_F(XattrTest, XattrNullName) {
SyscallFailsWithErrno(EFAULT));
EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(EFAULT));
}
TEST_F(XattrTest, XattrEmptyName) {
@@ -53,6 +65,7 @@ TEST_F(XattrTest, XattrEmptyName) {
EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(ERANGE));
EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, ""), SyscallFailsWithErrno(ERANGE));
}
TEST_F(XattrTest, XattrLargeName) {
@@ -74,6 +87,7 @@ TEST_F(XattrTest, XattrLargeName) {
SyscallFailsWithErrno(ERANGE));
EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, name.c_str()), SyscallFailsWithErrno(ERANGE));
}
TEST_F(XattrTest, XattrInvalidPrefix) {
@@ -83,6 +97,8 @@ TEST_F(XattrTest, XattrInvalidPrefix) {
SyscallFailsWithErrno(EOPNOTSUPP));
EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(removexattr(path, name.c_str()),
+ SyscallFailsWithErrno(EOPNOTSUPP));
}
// Do not allow save/restore cycles after making the test file read-only, as
@@ -104,10 +120,16 @@ TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EACCES));
char buf = '-';
EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
}
// Do not allow save/restore cycles after making the test file write-only, as
@@ -128,6 +150,14 @@ TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES));
+
+ // listxattr will succeed even without read permissions.
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
}
TEST_F(XattrTest, XattrTrustedWithNonadmin) {
@@ -139,16 +169,24 @@ TEST_F(XattrTest, XattrTrustedWithNonadmin) {
const char name[] = "trusted.abc";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
TEST_F(XattrTest, XattrOnDirectory) {
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const char name[] = "user.test";
- EXPECT_THAT(setxattr(dir.path().c_str(), name, NULL, 0, /*flags=*/0),
+ EXPECT_THAT(setxattr(dir.path().c_str(), name, nullptr, 0, /*flags=*/0),
SyscallSucceeds());
- EXPECT_THAT(getxattr(dir.path().c_str(), name, NULL, 0),
+ EXPECT_THAT(getxattr(dir.path().c_str(), name, nullptr, 0),
SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(dir.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(dir.path().c_str(), name), SyscallSucceeds());
}
TEST_F(XattrTest, XattrOnSymlink) {
@@ -156,28 +194,38 @@ TEST_F(XattrTest, XattrOnSymlink) {
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
const char name[] = "user.test";
- EXPECT_THAT(setxattr(link.path().c_str(), name, NULL, 0, /*flags=*/0),
+ EXPECT_THAT(setxattr(link.path().c_str(), name, nullptr, 0, /*flags=*/0),
SyscallSucceeds());
- EXPECT_THAT(getxattr(link.path().c_str(), name, NULL, 0),
+ EXPECT_THAT(getxattr(link.path().c_str(), name, nullptr, 0),
SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(link.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(link.path().c_str(), name), SyscallSucceeds());
}
TEST_F(XattrTest, XattrOnInvalidFileTypes) {
const char name[] = "user.test";
char char_device[] = "/dev/zero";
- EXPECT_THAT(setxattr(char_device, name, NULL, 0, /*flags=*/0),
+ EXPECT_THAT(setxattr(char_device, name, nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(getxattr(char_device, name, NULL, 0),
+ EXPECT_THAT(getxattr(char_device, name, nullptr, 0),
SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(char_device, nullptr, 0), SyscallSucceedsWithValue(0));
// Use tmpfs, where creation of named pipes is supported.
const std::string fifo = NewTempAbsPathInDir("/dev/shm");
const char* path = fifo.c_str();
EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds());
- EXPECT_THAT(setxattr(path, name, NULL, 0, /*flags=*/0),
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(getxattr(path, name, NULL, 0), SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
}
TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
@@ -415,18 +463,104 @@ TEST_F(XattrTest, GetxattrNonexistentName) {
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
-TEST_F(XattrTest, LGetSetxattrOnSymlink) {
+TEST_F(XattrTest, Listxattr) {
+ const char* path = test_file_name_.c_str();
+ const std::string name = "user.test";
+ const std::string name2 = "user.test2";
+ const std::string name3 = "user.test3";
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name2.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name3.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> list(name.size() + 1 + name2.size() + 1 + name3.size() + 1);
+ char* buf = list.data();
+ EXPECT_THAT(listxattr(path, buf, XATTR_SIZE_MAX),
+ SyscallSucceedsWithValue(list.size()));
+
+ absl::flat_hash_set<std::string> got = {};
+ for (char* p = buf; p < buf + list.size(); p += strlen(p) + 1) {
+ got.insert(std::string{p});
+ }
+
+ absl::flat_hash_set<std::string> expected = {name, name2, name3};
+ EXPECT_EQ(got, expected);
+}
+
+TEST_F(XattrTest, ListxattrNoXattrs) {
+ const char* path = test_file_name_.c_str();
+
+ std::vector<char> list, expected;
+ EXPECT_THAT(listxattr(path, list.data(), sizeof(list)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(list, expected);
+
+ // Listxattr should succeed if there are no attributes, even if the buffer
+ // passed in is a nullptr.
+ EXPECT_THAT(listxattr(path, nullptr, sizeof(list)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, ListxattrNullBuffer) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(listxattr(path, nullptr, sizeof(name)),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, ListxattrSizeTooSmall) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ char list[sizeof(name) - 1];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, ListxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(listxattr(path, nullptr, 0),
+ SyscallSucceedsWithValue(sizeof(name)));
+}
+
+TEST_F(XattrTest, RemoveXattr) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, RemoveXattrNonexistentName) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, LXattrOnSymlink) {
+ const char name[] = "user.test";
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
- EXPECT_THAT(lsetxattr(link.path().c_str(), nullptr, nullptr, 0, 0),
+ EXPECT_THAT(lsetxattr(link.path().c_str(), name, nullptr, 0, 0),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(lgetxattr(link.path().c_str(), nullptr, nullptr, 0),
+ EXPECT_THAT(lgetxattr(link.path().c_str(), name, nullptr, 0),
SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(llistxattr(link.path().c_str(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lremovexattr(link.path().c_str(), name),
+ SyscallFailsWithErrno(EPERM));
}
-TEST_F(XattrTest, LGetSetxattrOnNonsymlink) {
+TEST_F(XattrTest, LXattrOnNonsymlink) {
const char* path = test_file_name_.c_str();
const char name[] = "user.test";
int val = 1234;
@@ -438,9 +572,16 @@ TEST_F(XattrTest, LGetSetxattrOnNonsymlink) {
EXPECT_THAT(lgetxattr(path, name, &buf, size),
SyscallSucceedsWithValue(size));
EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(llistxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(lremovexattr(path, name), SyscallSucceeds());
}
-TEST_F(XattrTest, FGetSetxattr) {
+TEST_F(XattrTest, XattrWithFD) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), 0));
const char name[] = "user.test";
@@ -453,6 +594,13 @@ TEST_F(XattrTest, FGetSetxattr) {
EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size),
SyscallSucceedsWithValue(size));
EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
}
} // namespace
diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh
deleted file mode 100755
index 864bb2de4..000000000
--- a/test/syscalls/syscall_test_runner.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# syscall_test_runner.sh is a simple wrapper around the go syscall test runner.
-# It exists so that we can build the syscall test runner once, and use it for
-# all syscall tests, rather than build it for each test run.
-
-set -euf -x -o pipefail
-
-echo -- "$@"
-
-if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
- mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
- chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
-fi
-
-# Get location of syscall_test_runner binary.
-readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner)
-
-# Pass the arguments of this script directly to the runner.
-exec "${runner}" "$@"
diff --git a/test/util/BUILD b/test/util/BUILD
index 1ac8b3fd6..8b5a0f25c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_library", "cc_test", "select_system")
+load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system")
package(
default_visibility = ["//:sandbox"],
@@ -41,7 +41,7 @@ cc_library(
":save_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -55,7 +55,7 @@ cc_library(
":posix_error",
":test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -67,7 +67,7 @@ cc_test(
":proc_util",
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -87,7 +87,7 @@ cc_library(
":file_descriptor",
":posix_error",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -101,7 +101,7 @@ cc_test(
":temp_path",
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -134,7 +134,7 @@ cc_library(
":cleanup",
":posix_error",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -183,7 +183,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -194,7 +194,7 @@ cc_test(
deps = [
":posix_error",
":test_main",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -218,7 +218,7 @@ cc_library(
":cleanup",
":posix_error",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -233,7 +233,7 @@ cc_library(
":test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -259,7 +259,8 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ gbenchmark,
],
)
@@ -291,7 +292,7 @@ cc_library(
":posix_error",
":test_util",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -302,7 +303,7 @@ cc_test(
deps = [
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -322,7 +323,7 @@ cc_library(
":file_descriptor",
":posix_error",
":save_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index ee1b341d7..caf19b24d 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -26,6 +26,17 @@
namespace gvisor {
namespace testing {
+
+// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
+// because "it isn't needed", even though Linux can return it via F_GETFL.
+#if defined(__x86_64__)
+constexpr int kOLargeFile = 00100000;
+#elif defined(__aarch64__)
+constexpr int kOLargeFile = 00400000;
+#else
+#error "Unknown architecture"
+#endif
+
// Returns a status or the current working directory.
PosixErrorOr<std::string> GetCWD();
diff --git a/test/util/platform_util.cc b/test/util/platform_util.cc
index 2724e63f3..c9200d381 100644
--- a/test/util/platform_util.cc
+++ b/test/util/platform_util.cc
@@ -20,10 +20,9 @@ namespace gvisor {
namespace testing {
PlatformSupport PlatformSupport32Bit() {
- if (GvisorPlatform() == Platform::kPtrace) {
+ if (GvisorPlatform() == Platform::kPtrace ||
+ GvisorPlatform() == Platform::kKVM) {
return PlatformSupport::NotSupported;
- } else if (GvisorPlatform() == Platform::kKVM) {
- return PlatformSupport::Segfault;
} else {
return PlatformSupport::Allowed;
}
diff --git a/test/util/signal_util.h b/test/util/signal_util.h
index bcf85c337..e7b66aa51 100644
--- a/test/util/signal_util.h
+++ b/test/util/signal_util.h
@@ -85,6 +85,20 @@ inline void FixupFault(ucontext_t* ctx) {
// The encoding is 0x48 0xab 0x00.
ctx->uc_mcontext.gregs[REG_RIP] += 3;
}
+#elif __aarch64__
+inline void Fault() {
+ // Zero and dereference x0.
+ asm("mov xzr, x0\r\n"
+ "str xzr, [x0]\r\n"
+ :
+ :
+ : "x0");
+}
+
+inline void FixupFault(ucontext_t* ctx) {
+ // Skip the bad instruction above.
+ ctx->uc_mcontext.pc += 4;
+}
#endif
} // namespace testing
diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc
index 35aacb172..9c10b6674 100644
--- a/test/util/temp_path.cc
+++ b/test/util/temp_path.cc
@@ -77,6 +77,7 @@ std::string NewTempAbsPath() {
std::string NewTempRelPath() { return NextTempBasename(); }
std::string GetAbsoluteTestTmpdir() {
+ // Note that TEST_TMPDIR is guaranteed to be set.
char* env_tmpdir = getenv("TEST_TMPDIR");
std::string tmp_dir =
env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp";
diff --git a/test/util/test_main.cc b/test/util/test_main.cc
index 5c7ee0064..1f389e58f 100644
--- a/test/util/test_main.cc
+++ b/test/util/test_main.cc
@@ -16,5 +16,5 @@
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index 15cbc6da6..95e1e0c96 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -69,7 +69,6 @@ bool IsRunningWithHostinet() {
"xchg %%rdi, %%rbx\n" \
: "=a"(a), "=D"(b), "=c"(c), "=d"(d) \
: "a"(a_inp), "2"(c_inp))
-#endif // defined(__x86_64__)
CPUVendor GetCPUVendor() {
uint32_t eax, ebx, ecx, edx;
@@ -86,6 +85,7 @@ CPUVendor GetCPUVendor() {
}
return CPUVendor::kUnknownVendor;
}
+#endif // defined(__x86_64__)
bool operator==(const KernelVersion& first, const KernelVersion& second) {
return first.major == second.major && first.minor == second.minor &&
diff --git a/test/util/test_util.h b/test/util/test_util.h
index 2d22b0eb8..c5cb9d6d6 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -771,6 +771,7 @@ std::string RunfilePath(std::string path);
#endif
void TestInit(int* argc, char*** argv);
+int RunAllTests(void);
} // namespace testing
} // namespace gvisor
diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc
index ba7c0a85b..7e1ad9e66 100644
--- a/test/util/test_util_impl.cc
+++ b/test/util/test_util_impl.cc
@@ -17,8 +17,12 @@
#include "gtest/gtest.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
+#include "benchmark/benchmark.h"
#include "test/util/logging.h"
+extern bool FLAGS_benchmark_list_tests;
+extern std::string FLAGS_benchmark_filter;
+
namespace gvisor {
namespace testing {
@@ -26,6 +30,7 @@ void SetupGvisorDeathTest() {}
void TestInit(int* argc, char*** argv) {
::testing::InitGoogleTest(argc, *argv);
+ benchmark::Initialize(argc, *argv);
::absl::ParseCommandLine(*argc, *argv);
// Always mask SIGPIPE as it's common and tests aren't expected to handle it.
@@ -34,5 +39,14 @@ void TestInit(int* argc, char*** argv) {
TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
}
+int RunAllTests() {
+ if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") {
+ benchmark::RunSpecifiedBenchmarks();
+ return 0;
+ } else {
+ return RUN_ALL_TESTS();
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/tools/build/BUILD b/tools/bazeldefs/BUILD
index 0c0ce3f4d..00a467473 100644
--- a/tools/build/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -6,5 +6,5 @@ genrule(
name = "loopback",
outs = ["loopback.txt"],
cmd = "touch $@",
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
)
diff --git a/tools/build/defs.bzl b/tools/bazeldefs/defs.bzl
index d0556abd1..905b16d41 100644
--- a/tools/build/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -18,7 +18,9 @@ cc_test = _cc_test
cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
go_image = _go_image
go_embed_data = _go_embed_data
-loopback = "//tools/build:loopback"
+gtest = "@com_google_googletest//:gtest"
+gbenchmark = "@com_google_benchmark//:benchmark"
+loopback = "//tools/bazeldefs:loopback"
proto_library = native.proto_library
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
@@ -69,7 +71,7 @@ def go_test(name, **kwargs):
**kwargs
)
-def py_requirement(name, direct = False):
+def py_requirement(name, direct = True):
return _py_requirement(name)
def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs):
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
new file mode 100644
index 000000000..92b0b5fc0
--- /dev/null
+++ b/tools/bazeldefs/platforms.bzl
@@ -0,0 +1,17 @@
+"""List of platforms."""
+
+# Platform to associated tags.
+platforms = {
+ "ptrace": [
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+ "kvm": [
+ "manual",
+ "local",
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+}
+
+default_platform = "ptrace"
diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl
new file mode 100644
index 000000000..558fb53ae
--- /dev/null
+++ b/tools/bazeldefs/tags.bzl
@@ -0,0 +1,40 @@
+"""List of special Go suffixes."""
+
+go_suffixes = [
+ "_386",
+ "_386_unsafe",
+ "_aarch64",
+ "_aarch64_unsafe",
+ "_amd64",
+ "_amd64_unsafe",
+ "_arm",
+ "_arm64",
+ "_arm64_unsafe",
+ "_arm_unsafe",
+ "_impl",
+ "_impl_unsafe",
+ "_linux",
+ "_linux_unsafe",
+ "_mips",
+ "_mips64",
+ "_mips64_unsafe",
+ "_mips64le",
+ "_mips64le_unsafe",
+ "_mips_unsafe",
+ "_mipsle",
+ "_mipsle_unsafe",
+ "_opts",
+ "_opts_unsafe",
+ "_ppc64",
+ "_ppc64_unsafe",
+ "_ppc64le",
+ "_ppc64le_unsafe",
+ "_riscv64",
+ "_riscv64_unsafe",
+ "_s390x",
+ "_s390x_unsafe",
+ "_sparc64",
+ "_sparc64_unsafe",
+ "_wasm",
+ "_wasm_unsafe",
+]
diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD
index 92ba8ab06..4f1a31a6d 100644
--- a/tools/checkunsafe/BUILD
+++ b/tools/checkunsafe/BUILD
@@ -5,7 +5,7 @@ package(licenses = ["notice"])
go_tool_library(
name = "checkunsafe",
srcs = ["check_unsafe.go"],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
deps = [
"@org_golang_x_tools//go/analysis:go_tool_library",
],
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 819f12b0d..15a310403 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
+load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Delegate directly.
cc_binary = _cc_binary
@@ -20,6 +22,8 @@ go_embed_data = _go_embed_data
go_image = _go_image
go_test = _go_test
go_tool_library = _go_tool_library
+gtest = _gtest
+gbenchmark = _gbenchmark
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
py_library = _py_library
@@ -31,6 +35,8 @@ select_system = _select_system
loopback = _loopback
default_installer = _default_installer
default_net_util = _default_net_util
+platforms = _platforms
+default_platform = _default_platform
def go_binary(name, **kwargs):
"""Wraps the standard go_binary.
@@ -44,7 +50,45 @@ def go_binary(name, **kwargs):
**kwargs
)
-def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs):
+def calculate_sets(srcs):
+ """Calculates special Go sets for templates.
+
+ Args:
+ srcs: the full set of Go sources.
+
+ Returns:
+ A dictionary of the form:
+
+ "": [src1.go, src2.go]
+ "suffix": [src3suffix.go, src4suffix.go]
+
+ Note that suffix will typically start with '_'.
+ """
+ result = dict()
+ for file in srcs:
+ if not file.endswith(".go"):
+ continue
+ target = ""
+ for suffix in go_suffixes:
+ if file.endswith(suffix + ".go"):
+ target = suffix
+ if not target in result:
+ result[target] = [file]
+ else:
+ result[target].append(file)
+ return result
+
+def go_imports(name, src, out):
+ """Simplify a single Go source file by eliminating unused imports."""
+ native.genrule(
+ name = name,
+ srcs = [src],
+ outs = [out],
+ tools = ["@org_golang_x_tools//cmd/goimports:goimports"],
+ cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
+ )
+
+def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs):
"""Wraps the standard go_library and does stateification and marshalling.
The recommended way is to use this rule with mostly identical configuration as the native
@@ -67,41 +111,62 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
imports: imports required for stateify.
stateify: whether statify is enabled (default: true).
marshal: whether marshal is enabled (default: false).
+ marshal_debug: whether the gomarshal tools emits debugging output (default: false).
**kwargs: standard go_library arguments.
"""
+ all_srcs = srcs
+ all_deps = deps
+ dirname, _, _ = native.package_name().rpartition("/")
+ full_pkg = dirname + "/" + name
if stateify:
# Only do stateification for non-state packages without manual autogen.
- go_stateify(
- name = name + "_state_autogen",
- srcs = [src for src in srcs if src.endswith(".go")],
- imports = imports,
- package = name,
- arch = select_arch(),
- out = name + "_state_autogen.go",
- )
- all_srcs = srcs + [name + "_state_autogen.go"]
- if "//pkg/state" not in deps:
- all_deps = deps + ["//pkg/state"]
- else:
- all_deps = deps
- else:
- all_deps = deps
- all_srcs = srcs
+ # First, we need to segregate the input files via the special suffixes,
+ # and calculate the final output set.
+ state_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in state_sets.items():
+ go_stateify(
+ name = name + suffix + "_state_autogen_with_imports",
+ srcs = src_subset,
+ imports = imports,
+ package = full_pkg,
+ out = name + suffix + "_state_autogen_with_imports.go",
+ )
+ go_imports(
+ name = name + suffix + "_state_autogen",
+ src = name + suffix + "_state_autogen_with_imports.go",
+ out = name + suffix + "_state_autogen.go",
+ )
+ all_srcs = all_srcs + [
+ name + suffix + "_state_autogen.go"
+ for suffix in state_sets.keys()
+ ]
+ if "//pkg/state" not in all_deps:
+ all_deps = all_deps + ["//pkg/state"]
+
if marshal:
- go_marshal(
- name = name + "_abi_autogen",
- srcs = [src for src in srcs if src.endswith(".go")],
- debug = False,
- imports = imports,
- package = name,
- )
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in marshal_sets.items():
+ go_marshal(
+ name = name + suffix + "_abi_autogen",
+ srcs = src_subset,
+ debug = select({
+ "//tools/go_marshal:marshal_config_verbose": True,
+ "//conditions:default": marshal_debug,
+ }),
+ imports = imports,
+ package = name,
+ )
extra_deps = [
dep
for dep in marshal_deps
if not dep in all_deps
]
all_deps = all_deps + extra_deps
- all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
+ all_srcs = all_srcs + [
+ name + suffix + "_abi_autogen_unsafe.go"
+ for suffix in marshal_sets.keys()
+ ]
_go_library(
name = name,
@@ -114,13 +179,16 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
# Ignore importpath for go_test.
kwargs.pop("importpath", None)
- _go_test(
- name = name + "_abi_autogen_test",
- srcs = [name + "_abi_autogen_test.go"],
- library = ":" + name,
- deps = marshal_test_deps,
- **kwargs
- )
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, _) in marshal_sets.items():
+ _go_test(
+ name = name + suffix + "_abi_autogen_test",
+ srcs = [name + suffix + "_abi_autogen_test.go"],
+ library = ":" + name,
+ deps = marshal_test_deps,
+ **kwargs
+ )
def proto_library(name, srcs, **kwargs):
"""Wraps the standard proto_library.
diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD
index 069df3856..32a949c93 100644
--- a/tools/go_generics/BUILD
+++ b/tools/go_generics/BUILD
@@ -9,7 +9,7 @@ go_binary(
"imports.go",
"remove.go",
],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
deps = ["//tools/go_generics/globals"],
)
diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD
index b7d35e272..2fd5a200d 100644
--- a/tools/go_generics/go_merge/BUILD
+++ b/tools/go_generics/go_merge/BUILD
@@ -5,5 +5,5 @@ package(licenses = ["notice"])
go_binary(
name = "go_merge",
srcs = ["main.go"],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
)
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
index 80d9c0504..be49cf9c8 100644
--- a/tools/go_marshal/BUILD
+++ b/tools/go_marshal/BUILD
@@ -12,3 +12,8 @@ go_binary(
"//tools/go_marshal/gomarshal",
],
)
+
+config_setting(
+ name = "marshal_config_verbose",
+ values = {"define": "gomarshal=verbose"},
+)
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index c92b59dd6..b5d5a4487 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -14,4 +14,5 @@ go_library(
visibility = [
"//:sandbox",
],
+ deps = ["//tools/tags"],
)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index af90bdecb..d365a1f3c 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -23,6 +23,9 @@ import (
"go/token"
"os"
"sort"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/tags"
)
const (
@@ -31,9 +34,9 @@ const (
usermemImport = "gvisor.dev/gvisor/pkg/usermem"
)
-// List of identifiers we use in generated code, that may conflict a
-// similarly-named source identifier. Avoid problems by refusing the generate
-// code when we see these.
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Abort gracefully when we see these to
+// avoid potentially confusing compilation failures in generated code.
//
// This only applies to import aliases at the moment. All other identifiers
// are qualified by a receiver argument, since they're struct fields.
@@ -41,10 +44,21 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len",
+ "ptr", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
+// Constructed fromt badIdents in init().
+var badIdentsMap map[string]struct{}
+
+func init() {
+ badIdentsMap = make(map[string]struct{})
+ for _, ident := range badIdents {
+ badIdentsMap[ident] = struct{}{}
+ }
+}
+
// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
@@ -85,16 +99,20 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
}
for _, i := range imports {
// All imports on the extra imports list are unconditionally marked as
- // used, so they're always added to the generated code.
+ // used, so that they're always added to the generated code.
g.imports.add(i).markUsed()
}
- g.imports.add(marshalImport).markUsed()
- // The follow imports may or may not be used by the generated
- // code, depending what's required for the target types. Don't
- // mark these imports as used by default.
- g.imports.add(usermemImport)
- g.imports.add(safecopyImport)
+
+ // The following imports may or may not be used by the generated code,
+ // depending on what's required for the target types. Don't mark these as
+ // used by default.
+ g.imports.add("io")
+ g.imports.add("reflect")
+ g.imports.add("runtime")
g.imports.add("unsafe")
+ g.imports.add(marshalImport)
+ g.imports.add(safecopyImport)
+ g.imports.add(usermemImport)
return &g, nil
}
@@ -104,6 +122,14 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G
func (g *Generator) writeHeader() error {
var b sourceBuffer
b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(g.inputs); len(t) > 0 {
+ b.emit(strings.Join(t.Lines(), "\n"))
+ b.emit("\n\n")
+ }
+
+ // Package header.
b.emit("package %s\n\n", g.pkg)
if err := b.write(g.output); err != nil {
return err
@@ -168,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
-// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
var types []*ast.TypeSpec
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -197,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
continue
}
for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier.
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
t := spec.(*ast.TypeSpec)
- if _, ok := t.Type.(*ast.StructType); ok {
- debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
types = append(types, t)
continue
}
- debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
}
return types
@@ -218,11 +252,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
// identifiers in the generated code don't conflict with any imported package
// names.
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
- badImportNames := make(map[string]bool)
- for _, i := range badIdents {
- badImportNames[i] = true
- }
-
is := make(map[string]importStmt)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -239,7 +268,7 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
if len(i.name) == 1 {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
- if badImportNames[i.name] {
+ if _, ok := badIdentsMap[i.name]; ok {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
}
}
@@ -249,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- // We're guaranteed to have only struct type specs by now. See
- // Generator.collectMarshallabeTypes.
i := newInterfaceGenerator(t, fset)
- i.validate()
- i.emitMarshallable()
- return i
+ switch ty := t.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct()
+ i.emitMarshallableForStruct()
+ return i
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype()
+ return i
+ default:
+ // This should've been filtered out by collectMarshallabeTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
}
// generateOneTestSuite generates a test suite for the automatically generated
@@ -300,7 +337,7 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
- for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
for ref, _ := range impl.ms {
@@ -318,17 +355,6 @@ func (g *Generator) Run() error {
}
}
- // Tool was invoked with input files with no data structures marked for code
- // generation. This is probably not what the user intended.
- if len(impls) == 0 {
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
- for _, i := range g.inputs {
- fmt.Fprintf(&buf, " %s\n", i)
- }
- abort(buf.String())
- }
-
// Write output file header. These include things like package name and
// import statements.
if err := g.writeHeader(); err != nil {
@@ -360,6 +386,7 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Collect and write test import statements.
imports := newImportTable()
for _, t := range ts {
imports.merge(t.imports)
@@ -369,6 +396,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Write test functions.
+
+ // If we didn't generate any Marshallable implementations, we can't just
+ // emit an empty test file, since that causes the build to fail with "no
+ // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
+ // omit the entire package since the outputs are already defined before
+ // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
+ // empty example instead.
+ if len(ts) == 0 {
+ b.reset()
+ b.emit("func ExampleEmptyTestSuite() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty to ensure this file contains at least\n")
+ b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
+ b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
+ b.emit("// in a build failure.\n")
+ })
+ b.emit("}\n")
+ return b.write(g.outputTest)
+ }
+
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index a712c14dc..ea1af998e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct() {
g.forEachField(func(f *ast.Field) {
if len(f.Names) == 0 {
g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
@@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() {
g.forEachField(func(f *ast.Field) {
fieldDispatcher{
primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
+ g.validatePrimitiveNewtype(t)
},
selector: func(_, _, _ *ast.Ident) {
// No validation to perform on selector fields. However this
@@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
+func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
+// byte slice.
+func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
+// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
// packed. Returns "", false if g.t has no fields that may be potentially
// packed, otherwise returns <clause>, true, where <clause> is an expression
@@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
return strings.Join(cs, " && "), true
}
-func (g *interfaceGenerator) emitMarshallable() {
+func (g *interfaceGenerator) emitMarshallableForStruct() {
// Is g.t a packed struct without consideing field types?
thisPacked := true
g.forEachField(func(f *ast.Field) {
@@ -301,7 +330,7 @@ func (g *interfaceGenerator) emitMarshallable() {
primitiveSize += size
} else {
g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
}
},
selector: func(n, tX, tSel *ast.Ident) {
@@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
g.emit("}\n")
},
@@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
g.emit("}\n")
},
@@ -504,4 +533,290 @@ func (g *interfaceGenerator) emitMarshallable() {
})
g.emit("}\n\n")
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("return err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("if err != nil {\n")
+ g.inIndent(func() {
+ g.emit("return err\n")
+ })
+ g.emit("}\n")
+
+ g.emit("%s.UnmarshalBytes(buf)\n", g.r)
+ g.emit("return nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("n, err := w.Write(buf)\n")
+ g.emit("return int64(n), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ nt := g.t.Type.(*ast.Ident)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+
}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index bcda17c3b..8ba47eb67 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -22,9 +22,11 @@ import (
)
var standardImports = []string{
+ "bytes",
"fmt",
"reflect",
"testing",
+
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
@@ -47,9 +49,6 @@ type testGenerator struct {
}
func newTestGenerator(t *ast.TypeSpec) *testGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &testGenerator{
t: t,
r: receiverName(t),
@@ -67,14 +66,6 @@ func (g *testGenerator) typeName() string {
return g.t.Name.Name
}
-func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
func (g *testGenerator) testFuncName(base string) string {
return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
}
@@ -87,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) {
func (g *testGenerator) emitTestNonZeroSize() {
g.inTestFunction("TestSizeNonZero", func() {
- g.emit("x := &%s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("if x.SizeBytes() == 0 {\n")
g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
})
g.emit("}\n")
})
@@ -98,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() {
func (g *testGenerator) emitTestSuspectAlignment() {
g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("x := %s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
})
}
@@ -116,26 +107,64 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
g.emit("y.UnmarshalBytes(buf)\n")
g.emit("if !reflect.DeepEqual(x, y) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
})
g.emit("}\n")
g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
})
g.emit("}\n\n")
g.emit("z.UnmarshalUnsafe(buf)\n")
g.emit("if !reflect.DeepEqual(x, z) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
})
g.emit("}\n")
g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
+ g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
+ g.emit("var x, y, yUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("var buf bytes.Buffer\n\n")
+
+ g.emit("x.WriteTo(&buf)\n")
+ g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
+ g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
+
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
+ g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
+ g.emit("var x %s\n", g.typeName())
+ g.emit("sizeFromConcrete := x.SizeBytes()\n")
+ g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)")
})
g.emit("}\n")
})
@@ -145,6 +174,8 @@ func (g *testGenerator) emitTests() {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
}
func (g *testGenerator) write(out io.Writer) error {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 967537abf..e2bca4e7c 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -219,6 +219,11 @@ type sourceBuffer struct {
b bytes.Buffer
}
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
func (b *sourceBuffer) incIndent() {
b.indent++
}
@@ -305,7 +310,7 @@ func (i *importStmt) markUsed() {
}
func (i *importStmt) equivalent(other *importStmt) bool {
- return i == other
+ return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}
// importTable represents a collection of importStmts.
@@ -324,7 +329,7 @@ func newImportTable() *importTable {
// result in a panic.
func (i *importTable) merge(other *importTable) {
for name, im := range other.is {
- if dup, ok := i.is[name]; ok && dup.equivalent(im) {
+ if dup, ok := i.is[name]; ok && !dup.equivalent(im) {
panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
}
@@ -332,16 +337,27 @@ func (i *importTable) merge(other *importTable) {
}
}
+func (i *importTable) addStmt(s *importStmt) *importStmt {
+ if old, ok := i.is[s.name]; ok && !old.equivalent(s) {
+ // A collision should always be between an import inserted by the
+ // go-marshal tool and an import from the original source file (assuming
+ // the original source file was valid). We could theoretically handle
+ // the collision by assigning a local name to our import. However, this
+ // would need to be plumbed throughout the generator. Given that
+ // collisions should be rare, simply panic on collision.
+ panic(fmt.Sprintf("Import collision: old: %s as %v; new: %v as %v", old.path, old.name, s.path, s.name))
+ }
+ i.is[s.name] = s
+ return s
+}
+
func (i *importTable) add(s string) *importStmt {
n := newImport(s)
- i.is[n.name] = n
- return n
+ return i.addStmt(n)
}
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- n := newImportFromSpec(spec, f)
- i.is[n.name] = n
- return n
+ return i.addStmt(newImportFromSpec(spec, f))
}
// Marks the import named n as used. If no such import is in the table, returns
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
index ad508c72f..bacfaa5a4 100644
--- a/tools/go_marshal/marshal/BUILD
+++ b/tools/go_marshal/marshal/BUILD
@@ -10,4 +10,7 @@ go_library(
visibility = [
"//:sandbox",
],
+ deps = [
+ "//pkg/usermem",
+ ],
)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index a313a27ed..f129788e0 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -20,10 +20,38 @@
// tools/go_marshal. See the go_marshal README for details.
package marshal
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Task provides a subset of kernel.Task, used in marshalling. We don't import
+// the kernel package directly to avoid circular dependency.
+type Task interface {
+ // CopyScratchBuffer provides a task goroutine-local scratch buffer. See
+ // kernel.CopyScratchBuffer.
+ CopyScratchBuffer(size int) []byte
+
+ // CopyOutBytes writes the contents of b to the task's memory. See
+ // kernel.CopyOutBytes.
+ CopyOutBytes(addr usermem.Addr, b []byte) (int, error)
+
+ // CopyInBytes reads the contents of the task's memory to b. See
+ // kernel.CopyInBytes.
+ CopyInBytes(addr usermem.Addr, b []byte) (int, error)
+}
+
// Marshallable represents a type that can be marshalled to and from memory.
type Marshallable interface {
+ io.WriterTo
+
// SizeBytes is the size of the memory representation of a type in
// marshalled form.
+ //
+ // SizeBytes must handle a nil receiver. Practically, this means SizeBytes
+ // cannot deference any fields on the object implementing it (but will
+ // likely make use of the type of these fields).
SizeBytes() int
// MarshalBytes serializes a copy of a type to dst. dst must be at least
@@ -48,13 +76,27 @@ type Marshallable interface {
// MarshalBytes.
MarshalUnsafe(dst []byte)
- // UnmarshalUnsafe deserializes a type directly to the underlying memory
- // allocated for the object by the runtime.
+ // UnmarshalUnsafe deserializes a type by directly copying to the underlying
+ // memory allocated for the object by the runtime.
//
// This allows much faster unmarshalling of types which have no implicit
// padding, see Marshallable.Packed. When Packed would return false,
// UnmarshalUnsafe should fall back to the safer but slower unmarshal
- // mechanism implemented in UnmarshalBytes (usually by calling
- // UnmarshalBytes directly).
+ // mechanism implemented in UnmarshalBytes.
UnmarshalUnsafe(src []byte)
+
+ // CopyIn deserializes a Marshallable type from a task's memory. This may
+ // only be called from a task goroutine. This is more efficient than calling
+ // UnmarshalUnsafe on Marshallable.Packed types, as the type being
+ // marshalled does not escape. The implementation should avoid creating
+ // extra copies in memory by directly deserializing to the object's
+ // underlying memory.
+ CopyIn(task Task, addr usermem.Addr) error
+
+ // CopyOut serializes a Marshallable type to a task's memory. This may only
+ // be called from a task goroutine. This is more efficient than calling
+ // MarshalUnsafe on Marshallable.Packed types, as the type being serialized
+ // does not escape. The implementation should avoid creating extra copies in
+ // memory by directly serializing from the object's underlying memory.
+ CopyOut(task Task, addr usermem.Addr) error
}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index e345e3a8e..f27c5ce52 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
licenses(["notice"])
@@ -27,3 +27,15 @@ go_library(
marshal = True,
deps = ["//tools/go_marshal/test/external"],
)
+
+go_binary(
+ name = "escape",
+ testonly = 1,
+ srcs = ["escape.go"],
+ gc_goopts = ["-m"],
+ deps = [
+ ":test",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
index e12403741..c79defe9e 100644
--- a/tools/go_marshal/test/benchmark_test.go
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/analysis"
- test "gvisor.dev/gvisor/tools/go_marshal/test"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
)
// Marshalling using the standard encoding/binary package.
diff --git a/tools/go_marshal/test/escape.go b/tools/go_marshal/test/escape.go
new file mode 100644
index 000000000..184f05ea3
--- /dev/null
+++ b/tools/go_marshal/test/escape.go
@@ -0,0 +1,114 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This binary provides a convienient target for analyzing how the go-marshal
+// API causes its various arguments to escape to the heap. To use, build and
+// observe the output from the go compiler's escape analysis:
+//
+// $ bazel build :escape
+// ...
+// escape.go:67:2: moved to heap: task
+// escape.go:77:31: make([]byte, size) escapes to heap
+// escape.go:87:31: make([]byte, size) escapes to heap
+// escape.go:96:6: moved to heap: stat
+// ...
+//
+// This is not an automated test, but simply a minimal binary for easy analysis.
+package main
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// dummyTask implements marshal.Task.
+type dummyTask struct {
+}
+
+func (*dummyTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (task *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := task.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalBytes(buf)
+ task.CopyOutBytes(addr, buf)
+}
+
+func (task *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := task.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalUnsafe(buf)
+ task.CopyOutBytes(addr, buf)
+}
+
+// Expected escapes:
+// - task: passed to marshal.Marshallable.CopyOut as the marshal.Task interface.
+func doCopyOut() {
+ task := dummyTask{}
+ var stat test.Stat
+ stat.CopyOut(&task, usermem.Addr(0xf000ba12))
+}
+
+// Expected escapes:
+// - buf: make allocates on the heap.
+func doMarshalBytesDirect() {
+ task := dummyTask{}
+ var stat test.Stat
+ buf := task.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalBytes(buf)
+ task.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// Expected escapes:
+// - buf: make allocates on the heap.
+func doMarshalUnsafeDirect() {
+ task := dummyTask{}
+ var stat test.Stat
+ buf := task.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalUnsafe(buf)
+ task.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// Expected escapes:
+// - stat: passed to dummyTask.MarshalBytes as the marshal.Marshallable interface.
+func doMarshalBytesViaMarshallable() {
+ task := dummyTask{}
+ var stat test.Stat
+ task.MarshalBytes(usermem.Addr(0xf000ba12), &stat)
+}
+
+// Expected escapes:
+// - stat: passed to dummyTask.MarshalUnsafe as the marshal.Marshallable interface.
+func doMarshalUnsafeViaMarshallable() {
+ task := dummyTask{}
+ var stat test.Stat
+ task.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat)
+}
+
+func main() {
+ doCopyOut()
+ doMarshalBytesDirect()
+ doMarshalUnsafeDirect()
+ doMarshalBytesViaMarshallable()
+ doMarshalUnsafeViaMarshallable()
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 8de02d707..93229dedb 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -103,3 +103,13 @@ type Stat struct {
CTime Timespec
_ [3]int64
}
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal
+type SignalSet uint64
+
+// SignalSetAlias is an example newtype on another marshallable type.
+//
+// +marshal
+type SignalSetAlias SignalSet
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
index a133d6f8b..503cdf2e5 100644
--- a/tools/go_stateify/BUILD
+++ b/tools/go_stateify/BUILD
@@ -5,5 +5,6 @@ package(licenses = ["notice"])
go_binary(
name = "stateify",
srcs = ["main.go"],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
+ deps = ["//tools/tags"],
)
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index 0f261d89f..6a5e666f0 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -6,8 +6,7 @@ def _go_stateify_impl(ctx):
# Run the stateify command.
args = ["-output=%s" % output.path]
- args.append("-pkg=%s" % ctx.attr.package)
- args.append("-arch=%s" % ctx.attr.arch)
+ args.append("-fullpkg=%s" % ctx.attr.package)
if ctx.attr._statepkg:
args.append("-statepkg=%s" % ctx.attr._statepkg)
if ctx.attr.imports:
@@ -44,18 +43,11 @@ for statified types.
mandatory = False,
),
"package": attr.string(
- doc = "The package name for the input sources.",
- mandatory = True,
- ),
- "arch": attr.string(
- doc = "Target platform.",
+ doc = "The fully qualified package name for the input sources.",
mandatory = True,
),
"out": attr.output(
- doc = """
-The name of the generated file output. This must not conflict with any other
-files and must be added to the srcs of the relevant go_library.
-""",
+ doc = "Name of the generator output file.",
mandatory = True,
),
"_tool": attr.label(
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 7d5d291e6..3437aa476 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -22,126 +22,22 @@ import (
"go/ast"
"go/parser"
"go/token"
- "io/ioutil"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
+
+ "gvisor.dev/gvisor/tools/tags"
)
var (
- pkg = flag.String("pkg", "", "output package")
+ fullPkg = flag.String("fullpkg", "", "fully qualified output package")
imports = flag.String("imports", "", "extra imports for the output file")
output = flag.String("output", "", "output file")
statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
- arch = flag.String("arch", "", "specify the target platform")
)
-// The known architectures.
-var okgoarch = []string{
- "386",
- "amd64",
- "arm",
- "arm64",
- "mips",
- "mipsle",
- "mips64",
- "mips64le",
- "ppc64",
- "ppc64le",
- "riscv64",
- "s390x",
- "sparc64",
- "wasm",
-}
-
-// readfile returns the content of the named file.
-func readfile(file string) string {
- data, err := ioutil.ReadFile(file)
- if err != nil {
- panic(fmt.Sprintf("readfile err: %v", err))
- }
- return string(data)
-}
-
-// matchfield reports whether the field (x,y,z) matches this build.
-// all the elements in the field must be satisfied.
-func matchfield(f string, goarch string) bool {
- for _, tag := range strings.Split(f, ",") {
- if !matchtag(tag, goarch) {
- return false
- }
- }
- return true
-}
-
-// matchtag reports whether the tag (x or !x) matches this build.
-func matchtag(tag string, goarch string) bool {
- if tag == "" {
- return false
- }
- if tag[0] == '!' {
- if len(tag) == 1 || tag[1] == '!' {
- return false
- }
- return !matchtag(tag[1:], goarch)
- }
- return tag == goarch
-}
-
-// canBuild reports whether we can build this file for target platform by
-// checking file name and build tags. The code is derived from the Go source
-// cmd.dist.build.shouldbuild.
-func canBuild(file, goTargetArch string) bool {
- name := filepath.Base(file)
- excluded := func(list []string, ok string) bool {
- for _, x := range list {
- if x == ok || (ok == "android" && x == "linux") || (ok == "illumos" && x == "solaris") {
- continue
- }
- i := strings.Index(name, x)
- if i <= 0 || name[i-1] != '_' {
- continue
- }
- i += len(x)
- if i == len(name) || name[i] == '.' || name[i] == '_' {
- return true
- }
- }
- return false
- }
- if excluded(okgoarch, goTargetArch) {
- return false
- }
-
- // Check file contents for // +build lines.
- for _, p := range strings.Split(readfile(file), "\n") {
- p = strings.TrimSpace(p)
- if p == "" {
- continue
- }
- if !strings.HasPrefix(p, "//") {
- break
- }
- if !strings.Contains(p, "+build") {
- continue
- }
- fields := strings.Fields(p[2:])
- if len(fields) < 1 || fields[0] != "+build" {
- continue
- }
- for _, p := range fields[1:] {
- if matchfield(p, goTargetArch) {
- goto fieldmatch
- }
- }
- return false
- fieldmatch:
- }
- return true
-}
-
// resolveTypeName returns a qualified type name.
func resolveTypeName(name string, typ ast.Expr) (field string, qualified string) {
for done := false; !done; {
@@ -275,7 +171,7 @@ func main() {
flag.Usage()
os.Exit(1)
}
- if *pkg == "" {
+ if *fullPkg == "" {
fmt.Fprintf(os.Stderr, "Error: package required.")
os.Exit(1)
}
@@ -307,7 +203,7 @@ func main() {
// Declare our emission closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
}
emitZeroCheck := func(name string) {
fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
@@ -329,9 +225,17 @@ func main() {
fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
}
- // Emit the package name.
+ // Automated warning.
fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
- fmt.Fprintf(outputFile, "package %s\n\n", *pkg)
+
+ // Emit build tags.
+ if t := tags.Aggregate(flag.Args()); len(t) > 0 {
+ fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
+ }
+
+ // Emit the package name.
+ _, pkg := filepath.Split(*fullPkg)
+ fmt.Fprintf(outputFile, "package %s\n\n", pkg)
// Emit the imports lazily.
var once sync.Once
@@ -364,10 +268,6 @@ func main() {
os.Exit(1)
}
- if !canBuild(filename, *arch) {
- continue
- }
-
files = append(files, f)
}
diff --git a/tools/images/BUILD b/tools/images/BUILD
index f1699b184..fe11f08a3 100644
--- a/tools/images/BUILD
+++ b/tools/images/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_binary")
+load("//tools:defs.bzl", "cc_binary", "gtest")
load("//tools/images:defs.bzl", "vm_image", "vm_test")
package(
@@ -32,8 +32,8 @@ cc_binary(
srcs = ["test.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
- "@com_google_googletest//:gtest",
],
)
diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl
index 32235813a..de365d153 100644
--- a/tools/images/defs.bzl
+++ b/tools/images/defs.bzl
@@ -57,7 +57,10 @@ def _vm_image_impl(ctx):
command = argv,
input_manifests = runfiles_manifests,
)
- return [DefaultInfo(files = depset([ctx.outputs.out]))]
+ return [DefaultInfo(
+ files = depset([ctx.outputs.out]),
+ runfiles = ctx.runfiles(files = [ctx.outputs.out]),
+ )]
_vm_image = rule(
attrs = {
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
index 01bc4de8c..d78a265ca 100644
--- a/tools/installers/BUILD
+++ b/tools/installers/BUILD
@@ -5,10 +5,15 @@ package(
licenses = ["notice"],
)
+filegroup(
+ name = "runsc",
+ srcs = ["//runsc"],
+)
+
sh_binary(
name = "head",
srcs = ["head.sh"],
- data = ["//runsc"],
+ data = [":runsc"],
)
sh_binary(
diff --git a/tools/installers/head.sh b/tools/installers/head.sh
index 4435cb27a..9de8f138c 100755
--- a/tools/installers/head.sh
+++ b/tools/installers/head.sh
@@ -15,7 +15,7 @@
# limitations under the License.
# Install our runtime.
-third_party/gvisor/runsc/runsc install
+$(dirname $0)/runsc install
# Restart docker.
service docker restart || true
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
index 7b1956454..52f9734a6 100755
--- a/tools/installers/master.sh
+++ b/tools/installers/master.sh
@@ -15,6 +15,21 @@
# limitations under the License.
# Install runsc from the master branch.
+set -e
+
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
-apt-get update && apt-get install -y runsc
+while true; do
+ if apt-get update; then
+ apt-get install -y runsc
+ break
+ fi
+ result=$?
+ # Check if apt update failed to aquire the file lock.
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+runsc install
+service docker restart
+
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index f33b902d6..4dbfe420a 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -21,13 +21,19 @@
set -xeu
# Check arguments.
-if [ "$#" -ne 2 ]; then
- echo "usage: $0 <commit|revid> <release.rc>"
+if [ "$#" -ne 3 ]; then
+ echo "usage: $0 <commit|revid> <release.rc> <message-file>"
exit 1
fi
declare -r target_commit="$1"
declare -r release="$2"
+declare -r message_file="$3"
+
+if ! [[ -r "${message_file}" ]]; then
+ echo "error: message file '${message_file}' is not readable."
+ exit 1
+fi
closest_commit() {
while read line; do
@@ -64,6 +70,6 @@ fi
# Tag the given commit (annotated, to record the committer).
declare -r tag="release-${release}"
-(git tag -m "Release ${release}" -a "${tag}" "${commit}" && \
+(git tag -F "${message_file}" -a "${tag}" "${commit}" && \
git push origin tag "${tag}") || \
(git tag -d "${tag}" && false)
diff --git a/tools/tags/BUILD b/tools/tags/BUILD
new file mode 100644
index 000000000..1c02e2c89
--- /dev/null
+++ b/tools/tags/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tags",
+ srcs = ["tags.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/tags/tags.go b/tools/tags/tags.go
new file mode 100644
index 000000000..f35904e0a
--- /dev/null
+++ b/tools/tags/tags.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tags is a utility for parsing build tags.
+package tags
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strings"
+)
+
+// OrSet is a set of tags on a single line.
+//
+// Note that tags may include ",", and we don't distinguish this case in the
+// logic below. Ideally, this constraints can be split into separate top-level
+// build tags in order to resolve any issues.
+type OrSet []string
+
+// Line returns the line for this or.
+func (or OrSet) Line() string {
+ return fmt.Sprintf("// +build %s", strings.Join([]string(or), " "))
+}
+
+// AndSet is the set of all OrSets.
+type AndSet []OrSet
+
+// Lines returns the lines to be printed.
+func (and AndSet) Lines() (ls []string) {
+ for _, or := range and {
+ ls = append(ls, or.Line())
+ }
+ return
+}
+
+// Join joins this AndSet with another.
+func (and AndSet) Join(other AndSet) AndSet {
+ return append(and, other...)
+}
+
+// Tags returns the unique set of +build tags.
+//
+// Derived form the runtime's canBuild.
+func Tags(file string) (tags AndSet) {
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ return nil
+ }
+ // Check file contents for // +build lines.
+ for _, p := range strings.Split(string(data), "\n") {
+ p = strings.TrimSpace(p)
+ if p == "" {
+ continue
+ }
+ if !strings.HasPrefix(p, "//") {
+ break
+ }
+ if !strings.Contains(p, "+build") {
+ continue
+ }
+ fields := strings.Fields(p[2:])
+ if len(fields) < 1 || fields[0] != "+build" {
+ continue
+ }
+ tags = append(tags, OrSet(fields[1:]))
+ }
+ return tags
+}
+
+// Aggregate aggregates all tags from a set of files.
+//
+// Note that these may be in conflict, in which case the build will fail.
+func Aggregate(files []string) (tags AndSet) {
+ for _, file := range files {
+ tags = tags.Join(Tags(file))
+ }
+ return tags
+}