-rw-r--r--benchmarks/BUILD1
-rw-r--r--benchmarks/README.md126
-rw-r--r--benchmarks/harness/BUILD5
-rw-r--r--benchmarks/harness/__init__.py9
-rw-r--r--benchmarks/harness/machine.py11
-rw-r--r--benchmarks/harness/machine_producers/BUILD40
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer.py268
-rw-r--r--benchmarks/harness/machine_producers/gcloud_producer_test.py48
-rw-r--r--benchmarks/harness/machine_producers/machine_producer.py21
-rw-r--r--benchmarks/harness/machine_producers/mock_producer.py23
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_five.json211
-rw-r--r--benchmarks/harness/machine_producers/testdata/get_one.json145
-rw-r--r--benchmarks/harness/ssh_connection.py9
-rw-r--r--benchmarks/runner/BUILD10
-rw-r--r--benchmarks/runner/__init__.py119
-rw-r--r--benchmarks/runner/commands.py135
-rw-r--r--benchmarks/runner/runner_test.py2
-rw-r--r--benchmarks/suites/http.py2
-rwxr-xr-xbenchmarks/tcp/tcp_benchmark.sh21
-rw-r--r--benchmarks/tcp/tcp_proxy.go66
-rw-r--r--benchmarks/workloads/BUILD40
-rw-r--r--benchmarks/workloads/ab/BUILD5
-rw-r--r--benchmarks/workloads/absl/BUILD5
-rw-r--r--benchmarks/workloads/curl/BUILD6
-rw-r--r--benchmarks/workloads/ffmpeg/BUILD6
-rw-r--r--benchmarks/workloads/fio/BUILD5
-rw-r--r--benchmarks/workloads/httpd/BUILD6
-rw-r--r--benchmarks/workloads/iperf/BUILD5
-rw-r--r--benchmarks/workloads/netcat/BUILD6
-rw-r--r--benchmarks/workloads/nginx/BUILD6
-rw-r--r--benchmarks/workloads/node/BUILD6
-rw-r--r--benchmarks/workloads/node_template/BUILD6
-rw-r--r--benchmarks/workloads/redis/BUILD6
-rw-r--r--benchmarks/workloads/redisbenchmark/BUILD5
-rw-r--r--benchmarks/workloads/ruby/BUILD13
-rw-r--r--benchmarks/workloads/ruby_template/BUILD7
-rw-r--r--benchmarks/workloads/sleep/BUILD6
-rw-r--r--benchmarks/workloads/sysbench/BUILD5
-rw-r--r--benchmarks/workloads/syscall/BUILD5
-rw-r--r--benchmarks/workloads/tensorflow/BUILD6
-rw-r--r--benchmarks/workloads/true/BUILD7
-rw-r--r--kokoro/issue_reviver.cfg15
-rw-r--r--pkg/abi/linux/netfilter.go89
-rw-r--r--pkg/abi/linux/time.go13
-rw-r--r--pkg/amutex/BUILD1
-rw-r--r--pkg/amutex/amutex_test.go3
-rw-r--r--pkg/atomicbitops/BUILD1
-rw-r--r--pkg/atomicbitops/atomic_bitops_test.go3
-rw-r--r--pkg/compressio/BUILD5
-rw-r--r--pkg/compressio/compressio.go2
-rw-r--r--pkg/control/server/BUILD1
-rw-r--r--pkg/control/server/server.go2
-rw-r--r--pkg/cpuid/cpuid.go44
-rw-r--r--pkg/eventchannel/BUILD2
-rw-r--r--pkg/eventchannel/event.go2
-rw-r--r--pkg/eventchannel/event_test.go2
-rw-r--r--pkg/fdchannel/BUILD1
-rw-r--r--pkg/fdchannel/fdchannel_test.go3
-rw-r--r--pkg/fdnotifier/BUILD1
-rw-r--r--pkg/fdnotifier/fdnotifier.go2
-rw-r--r--pkg/flipcall/BUILD3
-rw-r--r--pkg/flipcall/flipcall_example_test.go3
-rw-r--r--pkg/flipcall/flipcall_test.go3
-rw-r--r--pkg/flipcall/flipcall_unsafe.go10
-rw-r--r--pkg/gate/BUILD1
-rw-r--r--pkg/gate/gate_test.go2
-rw-r--r--pkg/goid/BUILD26
-rw-r--r--pkg/goid/empty_test.go (renamed from pkg/sentry/socket/rpcinet/rpcinet.go)12
-rw-r--r--pkg/goid/goid.go24
-rw-r--r--pkg/goid/goid_amd64.s (renamed from pkg/sentry/socket/rpcinet/device.go)12
-rw-r--r--pkg/goid/goid_race.go25
-rw-r--r--pkg/goid/goid_test.go74
-rw-r--r--pkg/goid/goid_unsafe.go64
-rw-r--r--pkg/linewriter/BUILD1
-rw-r--r--pkg/linewriter/linewriter.go3
-rw-r--r--pkg/log/BUILD5
-rw-r--r--pkg/log/log.go2
-rw-r--r--pkg/metric/BUILD1
-rw-r--r--pkg/metric/metric.go2
-rw-r--r--pkg/p9/BUILD1
-rw-r--r--pkg/p9/client.go2
-rw-r--r--pkg/p9/client_file.go29
-rw-r--r--pkg/p9/file.go16
-rw-r--r--pkg/p9/handlers.go29
-rw-r--r--pkg/p9/messages.go129
-rw-r--r--pkg/p9/messages_test.go15
-rw-r--r--pkg/p9/p9.go4
-rw-r--r--pkg/p9/p9test/BUILD2
-rw-r--r--pkg/p9/p9test/client_test.go2
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/path_tree.go3
-rw-r--r--pkg/p9/pool.go2
-rw-r--r--pkg/p9/server.go2
-rw-r--r--pkg/p9/transport.go2
-rw-r--r--pkg/p9/version.go8
-rw-r--r--pkg/procid/BUILD2
-rw-r--r--pkg/procid/procid_test.go3
-rw-r--r--pkg/rand/BUILD5
-rw-r--r--pkg/rand/rand_linux.go2
-rw-r--r--pkg/refs/BUILD2
-rw-r--r--pkg/refs/refcounter.go2
-rw-r--r--pkg/refs/refcounter_test.go3
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/arch/arch_aarch64.go293
-rw-r--r--pkg/sentry/arch/arch_arm64.go266
-rw-r--r--pkg/sentry/arch/arch_state_aarch64.go (renamed from pkg/sentry/fsimpl/proc/mounts.go)33
-rw-r--r--pkg/sentry/arch/arch_state_x86.go2
-rw-r--r--pkg/sentry/arch/arch_x86.go2
-rw-r--r--pkg/sentry/arch/registers.proto37
-rw-r--r--pkg/sentry/arch/signal.go250
-rw-r--r--pkg/sentry/arch/signal_amd64.go230
-rw-r--r--pkg/sentry/arch/signal_arm64.go126
-rw-r--r--pkg/sentry/arch/signal_stack.go2
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go62
-rw-r--r--pkg/sentry/control/BUILD1
-rw-r--r--pkg/sentry/control/pprof.go2
-rw-r--r--pkg/sentry/device/BUILD5
-rw-r--r--pkg/sentry/device/device.go2
-rw-r--r--pkg/sentry/fs/BUILD3
-rw-r--r--pkg/sentry/fs/copy_up.go11
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/dirent.go2
-rw-r--r--pkg/sentry/fs/dirent_cache.go3
-rw-r--r--pkg/sentry/fs/dirent_cache_limiter.go3
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD1
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_state.go2
-rw-r--r--pkg/sentry/fs/file.go2
-rw-r--r--pkg/sentry/fs/file_overlay.go4
-rw-r--r--pkg/sentry/fs/filesystems.go2
-rw-r--r--pkg/sentry/fs/fs.go3
-rw-r--r--pkg/sentry/fs/fsutil/BUILD1
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go2
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode.go44
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go2
-rw-r--r--pkg/sentry/fs/gofer/BUILD1
-rw-r--r--pkg/sentry/fs/gofer/context_file.go14
-rw-r--r--pkg/sentry/fs/gofer/inode.go20
-rw-r--r--pkg/sentry/fs/gofer/session.go2
-rw-r--r--pkg/sentry/fs/host/BUILD1
-rw-r--r--pkg/sentry/fs/host/inode.go2
-rw-r--r--pkg/sentry/fs/host/socket.go2
-rw-r--r--pkg/sentry/fs/host/tty.go3
-rw-r--r--pkg/sentry/fs/inode.go27
-rw-r--r--pkg/sentry/fs/inode_inotify.go3
-rw-r--r--pkg/sentry/fs/inode_operations.go25
-rw-r--r--pkg/sentry/fs/inode_overlay.go41
-rw-r--r--pkg/sentry/fs/inode_overlay_test.go4
-rw-r--r--pkg/sentry/fs/inotify.go2
-rw-r--r--pkg/sentry/fs/inotify_watch.go2
-rw-r--r--pkg/sentry/fs/lock/BUILD1
-rw-r--r--pkg/sentry/fs/lock/lock.go5
-rw-r--r--pkg/sentry/fs/mounts.go2
-rw-r--r--pkg/sentry/fs/overlay.go5
-rw-r--r--pkg/sentry/fs/proc/BUILD3
-rw-r--r--pkg/sentry/fs/proc/cgroup.go4
-rw-r--r--pkg/sentry/fs/proc/cpuinfo.go12
-rw-r--r--pkg/sentry/fs/proc/exec_args.go4
-rw-r--r--pkg/sentry/fs/proc/fds.go4
-rw-r--r--pkg/sentry/fs/proc/filesystems.go4
-rw-r--r--pkg/sentry/fs/proc/fs.go4
-rw-r--r--pkg/sentry/fs/proc/inode.go4
-rw-r--r--pkg/sentry/fs/proc/loadavg.go4
-rw-r--r--pkg/sentry/fs/proc/meminfo.go4
-rw-r--r--pkg/sentry/fs/proc/mounts.go4
-rw-r--r--pkg/sentry/fs/proc/net.go4
-rw-r--r--pkg/sentry/fs/proc/proc.go13
-rw-r--r--pkg/sentry/fs/proc/rpcinet_proc.go217
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD1
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile.go2
-rw-r--r--pkg/sentry/fs/proc/stat.go4
-rw-r--r--pkg/sentry/fs/proc/sys.go13
-rw-r--r--pkg/sentry/fs/proc/sys_net.go6
-rw-r--r--pkg/sentry/fs/proc/task.go4
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go4
-rw-r--r--pkg/sentry/fs/proc/uptime.go4
-rw-r--r--pkg/sentry/fs/proc/version.go4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD1
-rw-r--r--pkg/sentry/fs/ramfs/dir.go2
-rw-r--r--pkg/sentry/fs/restore.go2
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD1
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go2
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go18
-rw-r--r--pkg/sentry/fs/tty/BUILD1
-rw-r--r--pkg/sentry/fs/tty/dir.go2
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go2
-rw-r--r--pkg/sentry/fs/tty/queue.go3
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD1
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go3
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go2
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD2
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go63
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go48
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go4
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go2
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go21
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD14
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go11
-rw-r--r--pkg/sentry/fsimpl/proc/loadavg.go42
-rw-r--r--pkg/sentry/fsimpl/proc/meminfo.go79
-rw-r--r--pkg/sentry/fsimpl/proc/stat.go129
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go126
-rw-r--r--pkg/sentry/fsimpl/proc/sys.go51
-rw-r--r--pkg/sentry/fsimpl/proc/task.go69
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go315
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go41
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go245
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go (renamed from pkg/sentry/fsimpl/proc/net.go)3
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go143
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys_test.go (renamed from pkg/sentry/fsimpl/proc/net_test.go)0
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go23
-rw-r--r--pkg/sentry/fsimpl/proc/version.go70
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go16
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go37
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go250
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go232
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go116
-rw-r--r--pkg/sentry/kernel/BUILD5
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go2
-rw-r--r--pkg/sentry/kernel/auth/BUILD3
-rw-r--r--pkg/sentry/kernel/auth/user_namespace.go2
-rw-r--r--pkg/sentry/kernel/epoll/BUILD1
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go2
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD1
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go2
-rw-r--r--pkg/sentry/kernel/fasync/BUILD1
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go3
-rw-r--r--pkg/sentry/kernel/fd_table.go2
-rw-r--r--pkg/sentry/kernel/fd_table_test.go2
-rw-r--r--pkg/sentry/kernel/fs_context.go2
-rw-r--r--pkg/sentry/kernel/futex/BUILD8
-rw-r--r--pkg/sentry/kernel/futex/futex.go3
-rw-r--r--pkg/sentry/kernel/futex/futex_test.go2
-rw-r--r--pkg/sentry/kernel/kernel.go2
-rw-r--r--pkg/sentry/kernel/memevent/BUILD1
-rw-r--r--pkg/sentry/kernel/memevent/memory_events.go2
-rw-r--r--pkg/sentry/kernel/pipe/BUILD1
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go2
-rw-r--r--pkg/sentry/kernel/pipe/node.go3
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go2
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go2
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go3
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD1
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go2
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go2
-rw-r--r--pkg/sentry/kernel/signal_handlers.go3
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD1
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go3
-rw-r--r--pkg/sentry/kernel/syscalls.go2
-rw-r--r--pkg/sentry/kernel/syslog.go3
-rw-r--r--pkg/sentry/kernel/task.go5
-rw-r--r--pkg/sentry/kernel/thread_group.go2
-rw-r--r--pkg/sentry/kernel/threads.go2
-rw-r--r--pkg/sentry/kernel/time/BUILD1
-rw-r--r--pkg/sentry/kernel/time/time.go2
-rw-r--r--pkg/sentry/kernel/timekeeper.go2
-rw-r--r--pkg/sentry/kernel/tty.go2
-rw-r--r--pkg/sentry/kernel/uts_namespace.go3
-rw-r--r--pkg/sentry/limits/BUILD1
-rw-r--r--pkg/sentry/limits/limits.go3
-rw-r--r--pkg/sentry/mm/BUILD2
-rw-r--r--pkg/sentry/mm/aio_context.go3
-rw-r--r--pkg/sentry/mm/mm.go8
-rw-r--r--pkg/sentry/pgalloc/BUILD1
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go2
-rw-r--r--pkg/sentry/platform/interrupt/BUILD1
-rw-r--r--pkg/sentry/platform/interrupt/interrupt.go3
-rw-r--r--pkg/sentry/platform/kvm/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/address_space.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go2
-rw-r--r--pkg/sentry/platform/kvm/kvm.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go4
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go4
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go2
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/defs.go2
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go1
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go1
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s10
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD5
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_x86.go2
-rw-r--r--pkg/sentry/socket/control/control.go2
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go408
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/port/port.go3
-rw-r--r--pkg/sentry/socket/netlink/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go152
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD69
-rw-r--r--pkg/sentry/socket/rpcinet/conn/BUILD17
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go187
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD16
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go231
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go909
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go177
-rw-r--r--pkg/sentry/socket/rpcinet/stack_unsafe.go193
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto352
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go3
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go2
-rw-r--r--pkg/sentry/socket/unix/unix.go5
-rw-r--r--pkg/sentry/strace/BUILD3
-rw-r--r--pkg/sentry/strace/linux64_amd64.go (renamed from pkg/sentry/strace/linux64.go)19
-rw-r--r--pkg/sentry/strace/linux64_arm64.go323
-rw-r--r--pkg/sentry/strace/socket.go2
-rw-r--r--pkg/sentry/strace/syscalls.go9
-rw-r--r--pkg/sentry/syscalls/linux/BUILD3
-rw-r--r--pkg/sentry/syscalls/linux/error.go2
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_amd64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_arm64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go13
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go48
-rw-r--r--pkg/sentry/time/BUILD4
-rw-r--r--pkg/sentry/time/calibrated_clock.go2
-rw-r--r--pkg/sentry/usage/BUILD1
-rw-r--r--pkg/sentry/usage/memory.go2
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/dentry.go2
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go2
-rw-r--r--pkg/sentry/vfs/mount_test.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go4
-rw-r--r--pkg/sentry/vfs/pathname.go3
-rw-r--r--pkg/sentry/vfs/permissions.go24
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go2
-rw-r--r--pkg/sentry/watchdog/BUILD1
-rw-r--r--pkg/sentry/watchdog/watchdog.go2
-rw-r--r--pkg/sleep/sleep_test.go31
-rw-r--r--pkg/sync/BUILD (renamed from pkg/syncutil/BUILD)9
-rw-r--r--pkg/sync/LICENSE (renamed from pkg/syncutil/LICENSE)0
-rw-r--r--pkg/sync/README.md (renamed from pkg/syncutil/README.md)0
-rw-r--r--pkg/sync/aliases.go37
-rw-r--r--pkg/sync/atomicptr_unsafe.go (renamed from pkg/syncutil/atomicptr_unsafe.go)0
-rw-r--r--pkg/sync/atomicptrtest/BUILD (renamed from pkg/syncutil/atomicptrtest/BUILD)4
-rw-r--r--pkg/sync/atomicptrtest/atomicptr_test.go (renamed from pkg/syncutil/atomicptrtest/atomicptr_test.go)0
-rw-r--r--pkg/sync/downgradable_rwmutex_test.go (renamed from pkg/syncutil/downgradable_rwmutex_test.go)2
-rw-r--r--pkg/sync/downgradable_rwmutex_unsafe.go (renamed from pkg/syncutil/downgradable_rwmutex_unsafe.go)2
-rw-r--r--pkg/sync/memmove_unsafe.go (renamed from pkg/syncutil/memmove_unsafe.go)2
-rw-r--r--pkg/sync/norace_unsafe.go (renamed from pkg/syncutil/norace_unsafe.go)2
-rw-r--r--pkg/sync/race_unsafe.go (renamed from pkg/syncutil/race_unsafe.go)2
-rw-r--r--pkg/sync/seqatomic_unsafe.go (renamed from pkg/syncutil/seqatomic_unsafe.go)16
-rw-r--r--pkg/sync/seqatomictest/BUILD (renamed from pkg/syncutil/seqatomictest/BUILD)10
-rw-r--r--pkg/sync/seqatomictest/seqatomic_test.go (renamed from pkg/syncutil/seqatomictest/seqatomic_test.go)18
-rw-r--r--pkg/sync/seqcount.go (renamed from pkg/syncutil/seqcount.go)2
-rw-r--r--pkg/sync/seqcount_test.go (renamed from pkg/syncutil/seqcount_test.go)2
-rw-r--r--pkg/sync/syncutil.go (renamed from pkg/syncutil/syncutil.go)4
-rw-r--r--pkg/tcpip/BUILD3
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go2
-rw-r--r--pkg/tcpip/checker/checker.go22
-rw-r--r--pkg/tcpip/header/BUILD1
-rw-r--r--pkg/tcpip/header/ipv6.go50
-rw-r--r--pkg/tcpip/header/ipv6_test.go96
-rw-r--r--pkg/tcpip/header/ndp_router_solicit.go36
-rw-r--r--pkg/tcpip/iptables/BUILD5
-rw-r--r--pkg/tcpip/iptables/iptables.go120
-rw-r--r--pkg/tcpip/iptables/targets.go16
-rw-r--r--pkg/tcpip/iptables/types.go49
-rw-r--r--pkg/tcpip/link/fdbased/BUILD1
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go2
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD2
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD1
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go2
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go2
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD1
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go2
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go2
-rw-r--r--pkg/tcpip/ports/BUILD1
-rw-r--r--pkg/tcpip/ports/ports.go2
-rw-r--r--pkg/tcpip/stack/BUILD4
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/ndp.go239
-rw-r--r--pkg/tcpip/stack/ndp_test.go536
-rw-r--r--pkg/tcpip/stack/nic.go232
-rw-r--r--pkg/tcpip/stack/stack.go87
-rw-r--r--pkg/tcpip/stack/stack_test.go192
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go56
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go89
-rw-r--r--pkg/tcpip/tcpip.go12
-rw-r--r--pkg/tcpip/timer_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/BUILD1
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go25
-rw-r--r--pkg/tcpip/transport/packet/BUILD1
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go3
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go3
-rw-r--r--pkg/tcpip/transport/tcp/BUILD16
-rw-r--r--pkg/tcpip/transport/tcp/accept.go11
-rw-r--r--pkg/tcpip/transport/tcp/connect.go340
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go224
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go424
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go32
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go3
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go27
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go2
-rw-r--r--pkg/tcpip/transport/tcp/snd.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go204
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go111
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go98
-rw-r--r--pkg/tmutex/BUILD1
-rw-r--r--pkg/tmutex/tmutex_test.go3
-rw-r--r--pkg/unet/BUILD1
-rw-r--r--pkg/unet/unet_test.go3
-rw-r--r--pkg/urpc/BUILD1
-rw-r--r--pkg/urpc/urpc.go2
-rw-r--r--pkg/waiter/BUILD1
-rw-r--r--pkg/waiter/waiter.go2
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/compat.go2
-rw-r--r--runsc/boot/limits.go2
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/loader_test.go2
-rw-r--r--runsc/cmd/BUILD1
-rw-r--r--runsc/cmd/create.go1
-rw-r--r--runsc/cmd/gofer.go2
-rw-r--r--runsc/cmd/start.go1
-rw-r--r--runsc/container/BUILD2
-rw-r--r--runsc/container/console_test.go2
-rw-r--r--runsc/container/container_test.go2
-rw-r--r--runsc/container/multi_container_test.go2
-rw-r--r--runsc/container/state_file.go2
-rw-r--r--runsc/fsgofer/BUILD1
-rw-r--r--runsc/fsgofer/fsgofer.go12
-rw-r--r--runsc/sandbox/BUILD1
-rw-r--r--runsc/sandbox/network.go23
-rw-r--r--runsc/sandbox/sandbox.go2
-rw-r--r--runsc/testutil/BUILD1
-rw-r--r--runsc/testutil/testutil.go2
-rwxr-xr-xscripts/common.sh2
-rwxr-xr-xscripts/issue_reviver.sh27
-rw-r--r--test/iptables/filter_input.go16
-rw-r--r--test/iptables/filter_output.go16
-rw-r--r--test/iptables/iptables_test.go2
-rw-r--r--test/iptables/iptables_util.go36
-rw-r--r--test/iptables/nat.go10
-rw-r--r--test/syscalls/BUILD5
-rw-r--r--test/syscalls/linux/BUILD1
-rw-r--r--test/syscalls/linux/inotify.cc4
-rw-r--r--test/syscalls/linux/poll.cc3
-rw-r--r--test/syscalls/linux/preadv2.cc2
-rw-r--r--test/syscalls/linux/readv_common.cc2
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc25
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc185
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc40
-rw-r--r--test/syscalls/linux/tcp_socket.cc14
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc8
-rw-r--r--test/syscalls/linux/xattr.cc152
464 files changed, 9114 insertions, 4821 deletions
diff --git a/benchmarks/BUILD b/benchmarks/BUILD
index dbadeeaf2..1455c6c5b 100644
--- a/benchmarks/BUILD
+++ b/benchmarks/BUILD
@@ -5,5 +5,6 @@ py_binary(
srcs = ["run.py"],
main = "run.py",
python_version = "PY3",
+ srcs_version = "PY3",
deps = ["//benchmarks/runner"],
)
diff --git a/benchmarks/README.md b/benchmarks/README.md
index ad44cd6ac..ff21614c5 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -6,66 +6,55 @@ These scripts are tools for collecting performance data for Docker-based tests.
The scripts assume the following:
-* You have a local machine with bazel installed.
-* You have some machine(s) with docker installed. These machines will be
- refered to as the "Environment".
-* Environment machines have the runtime(s) under test installed, such that you
- can run docker with a command like: `docker run --runtime=$RUNTIME
- your/image`.
-* You are able to login to machines in the environment with the local machine
- via ssh and the user for ssh can run docker commands without using `sudo`.
+* There are two sets of machines: one where the scripts will be run
+ (controller) and one or more machines on which docker containers will be run
+ (environment).
+* The controller machine must have bazel installed along with this source
+ code. You should be able to run a command like `bazel run :benchmarks --
+ --list`
+* Environment machines must have docker and the required runtimes installed.
+ More specifically, you should be able to run a command like: `docker run
+ --runtime=$RUNTIME your/image`.
+* The controller has an ssh private key which can be used to log in to environment
+ machines and run docker commands without using `sudo`. This is not required
+ if running locally via the `run-local` command.
* The docker daemon on each of your environment machines is listening on
`unix:///var/run/docker.sock` (docker's default).
For configuring the environment manually, consult the
[dockerd documentation][dockerd].
-## Environment
-
-All benchmarks require a user defined yaml file describe the environment. These
-files are of the form:
-
-```yaml
-machine1: local
-machine2:
- hostname: 100.100.100.100
- username: username
- key_path: ~/private_keyfile
- key_password: passphrase
-machine3:
- hostname: 100.100.100.101
- username: username
- key_path: ~/private_keyfile
- key_password: passphrase
-```
+## Running benchmarks
-The yaml file defines an environment with three machines named `machine1`,
-`machine2` and `machine3`. `machine1` is the local machine, `machine2` and
-`machine3` are remote machines. Both `machine2` and `machine3` should be
-reachable by `ssh`. For example, the command `ssh -i ~/private_keyfile
-username@100.100.100.100` (using the passphrase `passphrase`) should connect to
-`machine2`.
+Run the following from the benchmarks directory:
-The above is an example only. Machines should be uniform, since they are treated
-as such by the tests. Machines must also be accessible to each other via their
-default routes. Furthermore, some benchmarks will meaningless if running on the
-local machine, such as density.
+```bash
+bazel run :benchmarks -- run-local startup
-For remote machines, `hostname`, `key_path`, and `username` are required and
-others are optional. In addition key files must be generated
-[using the instrcutions below](#generating-ssh-keys).
+...
+method,metric,result
+startup.empty,startup_time_ms,652.5772
+startup.node,startup_time_ms,1654.4042000000002
+startup.ruby,startup_time_ms,1429.835
+```
-The above yaml file can be checked for correctness with the `validate` command
-in the top level perf.py script:
+The above command ran the startup benchmark locally, which consists of three
+benchmarks (empty, node, and ruby). The benchmarks ran on the default
+runtime, runc. Running them on another installed runtime, such as runsc, is as
+simple as:
-`bazel run :benchmarks -- validate $PWD/examples/localhost.yaml`
+```bash
+bazel run :benchmarks -- run-local startup --runtime=runsc
+```
-## Running benchmarks
+There is help:
+
+```bash
+bazel run :benchmarks -- --help
+bazel run :benchmarks -- run-local --help
+```
To list available benchmarks, use the `list` command:
```bash
bazel run :benchmarks -- list
+ls
...
Benchmark: sysbench.cpu
@@ -75,24 +64,44 @@ Metrics: events_per_second
:param max_prime: The maximum prime number to search.
```
-To run benchmarks, use the `run` command. For example, to run the sysbench
-benchmark above:
+You can choose benchmarks by name or regex like:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml sysbench.cpu
+bazel run :benchmarks -- run-local startup.node
+...
+metric,result
+startup_time_ms,1671.7178000000001
+
+```
+
+or
+
+```bash
+bazel run :benchmarks -- run-local s
+...
+method,metric,result
+startup.empty,startup_time_ms,1792.8292
+startup.node,startup_time_ms,3113.5274
+startup.ruby,startup_time_ms,3025.2424
+sysbench.cpu,cpu_events_per_second,12661.47
+sysbench.memory,memory_ops_per_second,7228268.44
+sysbench.mutex,mutex_time,17.4835
+sysbench.mutex,mutex_latency,3496.7
+sysbench.mutex,mutex_deviation,0.04
+syscall.syscall,syscall_time_ns,2065.0
```
You can run parameterized benchmarks, for example to run with different
runtimes:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --runtime=runc --runtime=runsc sysbench.cpu
+bazel run :benchmarks -- run-local --runtime=runc --runtime=runsc sysbench.cpu
```
Or with different parameters:
```bash
-bazel run :benchmarks -- run --env $PWD/examples/localhost.yaml --max_prime=10 --max_prime=100 sysbench.cpu
+bazel run :benchmarks -- run-local --max_prime=10 --max_prime=100 sysbench.cpu
```
## Writing benchmarks
@@ -121,7 +130,7 @@ The harness requires workloads to run. These are all available in the
In general, a workload consists of a Dockerfile to build it (while these are not
hermetic, in general they should be as fixed and isolated as possible), some
-parses for output if required, parser tests and sample data. Provided the test
+parsers for output if required, parser tests and sample data. Provided the test
is named after the workload package and contains a function named `sample`, this
variable will be used to automatically mock workload output when the `--mock`
flag is provided to the main tool.
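A minimal sketch of the `sample` convention described above, assuming a hypothetical workload package; only the `sample` function name is prescribed by this README, while the module name, metric, and regex are illustrative:

```python
# Hypothetical parser module for a workload package, e.g.
# benchmarks/workloads/myworkload/__init__.py (name is illustrative).
import re


def sample(**kwargs) -> str:
  """Canned output used to mock the workload when --mock is passed."""
  return "requests per second: 1234.56\n"


def requests_per_second(run_output: str, **kwargs) -> float:
  """Parses the example metric out of raw workload output."""
  return float(re.search(r"requests per second: ([\d.]+)", run_output).group(1))
```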
@@ -149,24 +158,5 @@ To write a new benchmark, open a module in the `suites` directory and use the
above signature. You should add a descriptive doc string to describe what your
benchmark is and any test centric arguments.
-## Generating SSH Keys
-
-The scripts only support RSA Keys, and ssh library used in paramiko. Paramiko
-only supports RSA keys that look like the following (PEM format):
-
-```bash
-$ cat /path/to/ssh/key
-
------BEGIN RSA PRIVATE KEY-----
-...private key text...
------END RSA PRIVATE KEY-----
-
-```
-
-To generate ssh keys in PEM format, use the [`-t rsa -m PEM -b 4096`][RSA-keys].
-option.
-
[dockerd]: https://docs.docker.com/engine/reference/commandline/dockerd/
[docker-py]: https://docker-py.readthedocs.io/en/stable/
-[paramiko]: http://docs.paramiko.org/en/2.4/api/client.html
-[RSA-keys]: https://serverfault.com/questions/939909/ssh-keygen-does-not-create-rsa-private-key
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 9546220c4..081a74243 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -24,6 +24,7 @@ py_library(
name = "container",
srcs = ["container.py"],
deps = [
+ "//benchmarks/workloads",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -45,6 +46,7 @@ py_library(
"//benchmarks/harness:container",
"//benchmarks/harness:ssh_connection",
"//benchmarks/harness:tunnel_dispatcher",
+ "//benchmarks/harness/machine_mocks",
requirement("asn1crypto", False),
requirement("chardet", False),
requirement("certifi", False),
@@ -53,6 +55,7 @@ py_library(
requirement("idna", False),
requirement("ptyprocess", False),
requirement("requests", False),
+ requirement("six", False),
requirement("urllib3", False),
requirement("websocket-client", False),
],
@@ -64,7 +67,7 @@ py_library(
deps = [
"//benchmarks/harness",
requirement("bcrypt", False),
- requirement("cffi", False),
+ requirement("cffi", True),
requirement("paramiko", True),
requirement("cryptography", False),
],
diff --git a/benchmarks/harness/__init__.py b/benchmarks/harness/__init__.py
index a7f34da9e..61fd25f73 100644
--- a/benchmarks/harness/__init__.py
+++ b/benchmarks/harness/__init__.py
@@ -13,13 +13,20 @@
# limitations under the License.
"""Core benchmark utilities."""
+import getpass
import os
# LOCAL_WORKLOADS_PATH defines the path to use for local workloads. This is a
# format string that accepts a single string parameter.
LOCAL_WORKLOADS_PATH = os.path.join(
- os.path.dirname(__file__), "../workloads/{}")
+ os.path.dirname(__file__), "../workloads/{}/tar.tar")
# REMOTE_WORKLOADS_PATH defines the path to use for storing the workloads on the
# remote host. This is a format string that accepts a single string parameter.
REMOTE_WORKLOADS_PATH = "workloads/{}"
+
+# DEFAULT_USER is the default user running this script.
+DEFAULT_USER = getpass.getuser()
+
+# DEFAULT_USER_HOME is the home directory of the user running the script.
+DEFAULT_USER_HOME = os.environ["HOME"] if "HOME" in os.environ else ""
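A quick sketch of how these format strings resolve, assuming the existing `redis` workload as the example name:

```python
# Sketch only: resolving workload paths via the format strings defined above.
from benchmarks import harness

local_path = harness.LOCAL_WORKLOADS_PATH.format("redis")
remote_path = harness.REMOTE_WORKLOADS_PATH.format("redis")
print(local_path)   # .../benchmarks/workloads/redis/tar.tar
print(remote_path)  # workloads/redis
```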
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 66b719b63..2df4c9e31 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -160,15 +160,17 @@ class LocalMachine(Machine):
stdout, stderr = process.communicate()
return stdout.decode("utf-8"), stderr.decode("utf-8")
- def read(self, path: str) -> str:
+ def read(self, path: str) -> bytes:
# Read the exact path locally.
return open(path, "r").read()
def pull(self, workload: str) -> str:
# Run the docker build command locally.
logging.info("Building %s@%s locally...", workload, self._name)
- self.run("docker build --tag={} {}".format(
- workload, harness.LOCAL_WORKLOADS_PATH.format(workload)))
+ with open(harness.LOCAL_WORKLOADS_PATH.format(workload),
+ "rb") as dockerfile:
+ self._docker_client.images.build(
+ fileobj=dockerfile, tag=workload, custom_context=True)
return workload # Workload is the tag.
def container(self, image: str, **kwargs) -> container.Container:
@@ -212,6 +214,9 @@ class RemoteMachine(Machine):
# Push to the remote machine and build.
logging.info("Building %s@%s remotely...", workload, self._name)
remote_path = self._ssh_connection.send_workload(workload)
+ # Workloads are all tarballs.
+ self.run("tar -xvf {remote_path}/tar.tar -C {remote_path}".format(
+ remote_path=remote_path))
self.run("docker build --tag={} {}".format(workload, remote_path))
return workload # Workload is the tag.
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index a48da02a1..c4e943882 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -20,6 +20,7 @@ py_library(
srcs = ["mock_producer.py"],
deps = [
"//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_producer",
"//benchmarks/harness/machine_producers:machine_producer",
],
)
@@ -38,3 +39,42 @@ py_library(
name = "gcloud_mock_recorder",
srcs = ["gcloud_mock_recorder.py"],
)
+
+py_library(
+ name = "gcloud_producer",
+ srcs = ["gcloud_producer.py"],
+ deps = [
+ "//benchmarks/harness:machine",
+ "//benchmarks/harness/machine_producers:gcloud_mock_recorder",
+ "//benchmarks/harness/machine_producers:machine_producer",
+ ],
+)
+
+filegroup(
+ name = "test_data",
+ srcs = [
+ "testdata/get_five.json",
+ "testdata/get_one.json",
+ ],
+)
+
+py_library(
+ name = "gcloud_producer_test_lib",
+ srcs = ["gcloud_producer_test.py"],
+ deps = [
+ "//benchmarks/harness/machine_producers:machine_producer",
+ "//benchmarks/harness/machine_producers:mock_producer",
+ ],
+)
+
+py_test(
+ name = "gcloud_producer_test",
+ srcs = [":gcloud_producer_test_lib"],
+ data = [
+ ":test_data",
+ ],
+ python_version = "PY3",
+ tags = [
+ "local",
+ ],
+)
diff --git a/benchmarks/harness/machine_producers/gcloud_producer.py b/benchmarks/harness/machine_producers/gcloud_producer.py
new file mode 100644
index 000000000..e0b77d52b
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer.py
@@ -0,0 +1,268 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A machine producer which produces machine objects using `gcloud`.
+
+Machine producers produce valid harness.Machine objects which are backed by
+real machines. This producer produces those machines on the given user's GCP
+account using the `gcloud` tool.
+
+GCloudProducer creates instances on the given GCP account named like:
+`machine-XXXXXXX-XXXX-XXXX-XXXXXXXXXXXX` in a randomized fashion such that name
+collisions with user instances shouldn't happen.
+
+ Typical usage example:
+
+ producer = GCloudProducer(args)
+ machines = producer.get_machines(NUM_MACHINES)
+ # run stuff on machines with machines[i].run(CMD)
+ producer.release_machines(machines)
+"""
+import datetime
+import json
+import subprocess
+import threading
+from typing import List, Dict, Any, Optional
+import uuid
+
+from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import machine_producer
+
+
+class GCloudProducer(machine_producer.MachineProducer):
+ """Implementation of MachineProducer backed by GCP.
+
+ Produces Machine objects backed by GCP instances.
+
+ Attributes:
+ project: The GCP project name under which to create the machines.
+ ssh_key_file: path to a valid ssh private key. See README on valid ssh keys.
+ image: image name as a string.
+ image_project: image project as a string.
+ machine_type: type of GCP instance to create, e.g. n1-standard-4.
+ zone: string to a valid GCP zone.
+ ssh_user: string of user name for ssh_key
+ ssh_password: string of password for ssh key
+ mock: a mock printer which will print mock data if required. Mock data is
+ recorded output from subprocess calls (returncode, stdout, args).
+ condition: mutex for this class around machine creation and deletion.
+ """
+
+ def __init__(self,
+ project: str,
+ ssh_key_file: str,
+ image: str,
+ image_project: str,
+ machine_type: str,
+ zone: str,
+ ssh_user: str,
+ ssh_password: str,
+ mock: gcloud_mock_recorder.MockPrinter = None):
+ self.project = project
+ self.ssh_key_file = ssh_key_file
+ self.image = image
+ self.image_project = image_project
+ self.machine_type = machine_type
+ self.zone = zone
+ self.ssh_user = ssh_user
+ self.ssh_password = ssh_password
+ self.mock = mock
+ self.condition = threading.Condition()
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns requested number of machines backed by GCP instances."""
+ if num_machines <= 0:
+ raise ValueError(
+ "Cannot ask for {num} machines!".format(num=num_machines))
+ with self.condition:
+ names = self._get_unique_names(num_machines)
+ self._build_instances(names)
+ instances = self._start_command(names)
+ self._add_ssh_key_to_instances(names)
+ return self._machines_from_instances(instances)
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """Releases the requested number of machines, deleting the instances."""
+ if not machine_list:
+ return
+ cmd = "gcloud compute instances delete --quiet".split(" ")
+ names = [str(m) for m in machine_list]
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ self._run_command(cmd, detach=True)
+
+ def _machines_from_instances(
+ self, instances: List[Dict[str, Any]]) -> List[machine.Machine]:
+ """Creates Machine Objects from json data describing created instances."""
+ machines = []
+ for instance in instances:
+ name = instance["name"]
+ kwargs = {
+ "hostname":
+ instance["networkInterfaces"][0]["accessConfigs"][0]["natIP"],
+ "key_path":
+ self.ssh_key_file,
+ "username":
+ self.ssh_user,
+ "key_password":
+ self.ssh_password
+ }
+ machines.append(machine.RemoteMachine(name=name, **kwargs))
+ return machines
+
+ def _get_unique_names(self, num_names) -> List[str]:
+ """Returns num_names unique names based on data from the GCP project."""
+ curr_machines = self._list_machines()
+ curr_names = set([machine["name"] for machine in curr_machines])
+ ret = []
+ while len(ret) < num_names:
+ new_name = "machine-" + str(uuid.uuid4())
+ if new_name not in curr_names:
+ ret.append(new_name)
+ curr_names.add(new_name)
+ return ret
+
+ def _build_instances(self, names: List[str]) -> List[Dict[str, Any]]:
+ """Creates instances using gcloud command.
+
+ Runs the command `gcloud compute instances create` and returns json data
+ on created instances on success. Creates len(names) instances, one for each
+ name.
+
+ Args:
+ names: list of names of instances to create.
+
+ Returns:
+ List of json data describing created machines.
+ """
+ if not names:
+ raise ValueError(
+ "_build_instances cannot create instances without names.")
+ cmd = "gcloud compute instances create".split(" ")
+ cmd.extend(names)
+ cmd.extend(
+ "--preemptible --image={image} --zone={zone} --machine-type={machine_type}"
+ .format(
+ image=self.image, zone=self.zone,
+ machine_type=self.machine_type).split(" "))
+ if self.image_project:
+ cmd.append("--image-project={project}".format(project=self.image_project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _start_command(self, names):
+ """Starts instances using gcloud command.
+
+ Runs the command `gcloud compute instances start` on list of instances by
+ name and returns json data on started instances on success.
+
+ Args:
+ names: list of names of instances to start.
+
+ Returns:
+ List of json data describing started machines.
+ """
+ if not names:
+ raise ValueError("_start_command cannot start empty instance list.")
+ cmd = "gcloud compute instances start".split(" ")
+ cmd.extend(names)
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--project={project}".format(project=self.project))
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _add_ssh_key_to_instances(self, names: List[str]) -> None:
+ """Adds ssh key to instances by calling gcloud ssh command.
+
+ Runs the command `gcloud compute ssh instance_name` on the list of instances
+ by name and tries to ssh into each given instance.
+
+ Args:
+ names: list of machine names to which to add the ssh-key
+ self.ssh_key_file.
+
+ Raises:
+ subprocess.CalledProcessError: when underlying subprocess call returns an
+ error other than 255 (Connection closed by remote host).
+ TimeoutError: when 3 unsuccessful tries to ssh into the host return 255.
+ """
+ for name in names:
+ cmd = "gcloud compute ssh {name}".format(name=name).split(" ")
+ cmd.append("--ssh-key-file={key}".format(key=self.ssh_key_file))
+ cmd.append("--zone={zone}".format(zone=self.zone))
+ cmd.append("--command=uname")
+ timeout = datetime.timedelta(seconds=5 * 60)
+ start = datetime.datetime.now()
+ while datetime.datetime.now() <= timeout + start:
+ try:
+ self._run_command(cmd)
+ break
+ except subprocess.CalledProcessError as e:
+ if datetime.datetime.now() > timeout + start:
+ raise TimeoutError(
+ "Could not SSH into instance after 5 min: {name}".format(
+ name=name))
+ # 255 is the returncode for ssh connection refused.
+ elif e.returncode == 255:
+ continue
+ else:
+ raise e
+
+ def _list_machines(self) -> List[Dict[str, Any]]:
+ """Runs `list` gcloud command and returns list of Machine data."""
+ cmd = "gcloud compute instances list --project {project}".format(
+ project=self.project).split(" ")
+ res = self._run_command(cmd)
+ return json.loads(res.stdout)
+
+ def _run_command(self,
+ cmd: List[str],
+ detach: bool = False) -> Optional[subprocess.CompletedProcess]:
+ """Runs command as a subprocess.
+
+ Runs command as subprocess and returns the result.
+ If this has a mock recorder, use the record method to record the subprocess
+ call.
+
+ Args:
+ cmd: command to be run as a list of strings.
+ detach: if True, run the child process and don't wait for it to return.
+
+ Returns:
+ Completed process object to be parsed by caller or None if detach=True.
+
+ Raises:
+ CalledProcessError: if subprocess.run returns an error.
+ """
+ cmd = cmd + ["--format=json"]
+ if detach:
+ p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ out, _ = p.communicate()
+ self.mock.record(
+ subprocess.CompletedProcess(
+ returncode=p.returncode, stdout=out, args=p.args))
+ return
+
+ res = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if self.mock:
+ self.mock.record(res)
+ if res.returncode != 0:
+ raise subprocess.CalledProcessError(
+ cmd=res.args,
+ output=res.stdout,
+ stderr=res.stderr,
+ returncode=res.returncode)
+ return res
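A hedged sketch of driving `GCloudProducer` end to end with the constructor arguments defined above; the concrete values are placeholders echoing the testdata files, not defaults:

```python
# Illustrative only: create, exercise, and release GCP-backed machines.
from benchmarks.harness.machine_producers import gcloud_producer

producer = gcloud_producer.GCloudProducer(
    project="my-gcp-project",               # placeholder project name
    ssh_key_file="~/.ssh/benchmark-tools",  # key path as in the testdata below
    image="ubuntu-1910-eoan-v20191204",
    image_project="ubuntu-os-cloud",
    machine_type="n1-standard-4",
    zone="us-west1-b",
    ssh_user="user",
    ssh_password="")

machines = producer.get_machines(2)
try:
  for m in machines:
    print(m.run("uname -a"))  # run(CMD) per the module docstring's usage example
finally:
  producer.release_machines(machines)
```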
diff --git a/benchmarks/harness/machine_producers/gcloud_producer_test.py b/benchmarks/harness/machine_producers/gcloud_producer_test.py
new file mode 100644
index 000000000..c8adb2bdc
--- /dev/null
+++ b/benchmarks/harness/machine_producers/gcloud_producer_test.py
@@ -0,0 +1,48 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests GCloudProducer using mock data.
+
+GCloudProducer produces machines using 'get_machines' and 'release_machines'
+methods. The tests check recorded data (jsonified subprocess.CompletedProcess
+objects) of the producer producing one and five machines.
+"""
+import os
+import types
+
+from benchmarks.harness.machine_producers import machine_producer
+from benchmarks.harness.machine_producers import mock_producer
+
+TEST_DIR = os.path.dirname(__file__)
+
+
+def run_get_release(producer: machine_producer.MachineProducer,
+ num_machines: int,
+ validator: types.FunctionType = None):
+ machines = producer.get_machines(num_machines)
+ assert len(machines) == num_machines
+ if validator:
+ validator(machines=machines, cmd="uname -a", workload=None)
+ producer.release_machines(machines)
+
+
+def test_run_one():
+ mock = mock_producer.MockReader(TEST_DIR + "get_one.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 1)
+
+
+def test_run_five():
+ mock = mock_producer.MockReader(TEST_DIR + "get_five.json")
+ producer = mock_producer.MockGCloudProducer(mock)
+ run_get_release(producer, 5)
diff --git a/benchmarks/harness/machine_producers/machine_producer.py b/benchmarks/harness/machine_producers/machine_producer.py
index 124ee14cc..f5591c026 100644
--- a/benchmarks/harness/machine_producers/machine_producer.py
+++ b/benchmarks/harness/machine_producers/machine_producer.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Abstract types."""
+import threading
from typing import List
from benchmarks.harness import machine
@@ -28,3 +29,23 @@ class MachineProducer:
def release_machines(self, machine_list: List[machine.Machine]):
"""Releases the given set of machines."""
raise NotImplementedError
+
+
+class LocalMachineProducer(MachineProducer):
+ """Produces Local Machines."""
+
+ def __init__(self, limit: int):
+ self.limit_sem = threading.Semaphore(value=limit)
+
+ def get_machines(self, num_machines: int) -> List[machine.Machine]:
+ """Returns the request number of MockMachines."""
+
+ self.limit_sem.acquire()
+ return [machine.LocalMachine("local") for _ in range(num_machines)]
+
+ def release_machines(self, machine_list: List[machine.Machine]):
+ """No-op."""
+ if not machine_list:
+ raise ValueError("Cannot release an empty list!")
+ self.limit_sem.release()
+ machine_list.clear()
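A short sketch of the `LocalMachineProducer` added above; `limit=1` is an arbitrary illustrative cap on concurrent local runs:

```python
# Sketch only: produce and release local machines, gated by the semaphore above.
from benchmarks.harness.machine_producers import machine_producer

producer = machine_producer.LocalMachineProducer(limit=1)
machines = producer.get_machines(1)  # acquires the semaphore once per call
try:
  pass  # run a benchmark against machines[0] here
finally:
  producer.release_machines(machines)  # releases the semaphore, clears the list
```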
diff --git a/benchmarks/harness/machine_producers/mock_producer.py b/benchmarks/harness/machine_producers/mock_producer.py
index 4f29ad53f..37e9cb4b7 100644
--- a/benchmarks/harness/machine_producers/mock_producer.py
+++ b/benchmarks/harness/machine_producers/mock_producer.py
@@ -13,9 +13,11 @@
# limitations under the License.
"""Producers of mocks."""
-from typing import List
+from typing import List, Any
from benchmarks.harness import machine
+from benchmarks.harness.machine_producers import gcloud_mock_recorder
+from benchmarks.harness.machine_producers import gcloud_producer
from benchmarks.harness.machine_producers import machine_producer
@@ -29,3 +31,22 @@ class MockMachineProducer(machine_producer.MachineProducer):
def release_machines(self, machine_list: List[machine.MockMachine]):
"""No-op."""
return
+
+
+class MockGCloudProducer(gcloud_producer.GCloudProducer):
+ """Mocks GCloudProducer for testing purposes."""
+
+ def __init__(self, mock: gcloud_mock_recorder.MockReader, **kwargs):
+ gcloud_producer.GCloudProducer.__init__(
+ self, project="mock", ssh_private_key_path="mock", **kwargs)
+ self.mock = mock
+
+ def _validate_ssh_file(self):
+ pass
+
+ def _run_command(self, cmd):
+ return self.mock.pop(cmd)
+
+ def _machines_from_instances(
+ self, instances: List[Any]) -> List[machine.MockMachine]:
+ return [machine.MockMachine() for _ in instances]
diff --git a/benchmarks/harness/machine_producers/testdata/get_five.json b/benchmarks/harness/machine_producers/testdata/get_five.json
new file mode 100644
index 000000000..32bad1b06
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_five.json
@@ -0,0 +1,211 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--project=project",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]},{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-42c9bf6e-8d45-4c37-b1c0-7e4fdcf530fc",
+ "machine-5f28f145-cc2d-427d-9cbf-428d164cdb92",
+ "machine-da5859b5-bae6-435d-8005-0202d6f6e065",
+ "machine-880a8a2f-918c-4f9e-a43c-ed3c8e02ea05",
+ "machine-1149147d-71e2-43ea-8fe1-49256e5c441c",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/machine_producers/testdata/get_one.json b/benchmarks/harness/machine_producers/testdata/get_one.json
new file mode 100644
index 000000000..c359c19c8
--- /dev/null
+++ b/benchmarks/harness/machine_producers/testdata/get_one.json
@@ -0,0 +1,145 @@
+[
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "list",
+ "--project",
+ "linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "create",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--preemptible",
+ "--image=ubuntu-1910-eoan-v20191204",
+ "--zone=us-west1-b",
+ "--image-project=ubuntu-os-cloud",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "start",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--project=linux-testing-user",
+ "--format=json"
+ ],
+ "stdout": "[{\"name\":\"name\", \"networkInterfaces\":[{\"accessConfigs\":[{\"natIP\":\"0.0.0.0\"}]}]}]",
+
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "",
+ "returncode": "255"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "ssh",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--ssh-key-file=/usr/local/google/home/user/.ssh/benchmark-tools",
+ "--zone=us-west1-b",
+ "--command=uname",
+ "--format=json"
+ ],
+ "stdout": "Linux\n[]\n",
+ "returncode": "0"
+ },
+ {
+ "args": [
+ "gcloud",
+ "compute",
+ "instances",
+ "delete",
+ "--quiet",
+ "machine-129dfcf9-b05b-4c16-a4cd-21353b570ddc",
+ "--zone=us-west1-b",
+ "--format=json"
+ ],
+ "stdout": "[]\n",
+ "returncode": "0"
+ }
+]
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index fcbfbcdb2..e0bf258f1 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -94,7 +94,7 @@ class SSHConnection:
return stdout, stderr
def send_workload(self, name: str) -> str:
- """Sends a workload to the remote machine.
+ """Sends a workload tarball to the remote machine.
Args:
name: The workload name.
@@ -103,9 +103,6 @@ class SSHConnection:
The remote path.
"""
with self._client() as client:
- for dirpath, _, filenames in os.walk(
- harness.LOCAL_WORKLOADS_PATH.format(name)):
- for filename in filenames:
- send_one_file(client, os.path.join(dirpath, filename),
- harness.REMOTE_WORKLOADS_PATH.format(name))
+ send_one_file(client, harness.LOCAL_WORKLOADS_PATH.format(name),
+ harness.REMOTE_WORKLOADS_PATH.format(name))
return harness.REMOTE_WORKLOADS_PATH.format(name)
diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD
index de24824cc..e1b2ea550 100644
--- a/benchmarks/runner/BUILD
+++ b/benchmarks/runner/BUILD
@@ -10,7 +10,9 @@ py_library(
],
visibility = ["//benchmarks:__pkg__"],
deps = [
+ ":commands",
"//benchmarks/harness:benchmark_driver",
+ "//benchmarks/harness/machine_producers:machine_producer",
"//benchmarks/harness/machine_producers:mock_producer",
"//benchmarks/harness/machine_producers:yaml_producer",
"//benchmarks/suites",
@@ -30,6 +32,14 @@ py_library(
],
)
+py_library(
+ name = "commands",
+ srcs = ["commands.py"],
+ deps = [
+ requirement("click", True),
+ ],
+)
+
py_test(
name = "runner_test",
srcs = ["runner_test.py"],
diff --git a/benchmarks/runner/__init__.py b/benchmarks/runner/__init__.py
index 9bf9cfd65..ba80d83d7 100644
--- a/benchmarks/runner/__init__.py
+++ b/benchmarks/runner/__init__.py
@@ -15,10 +15,13 @@
import copy
import csv
+import json
import logging
+import os
import pkgutil
import pydoc
import re
+import subprocess
import sys
import types
from typing import List
@@ -26,10 +29,14 @@ from typing import Tuple
import click
+from benchmarks import harness
from benchmarks import suites
from benchmarks.harness import benchmark_driver
+from benchmarks.harness.machine_producers import gcloud_producer
+from benchmarks.harness.machine_producers import machine_producer
from benchmarks.harness.machine_producers import mock_producer
from benchmarks.harness.machine_producers import yaml_producer
+from benchmarks.runner import commands
@click.group()
@@ -100,30 +107,77 @@ def list_all(method):
print("\n")
-# pylint: disable=too-many-arguments
-# pylint: disable=too-many-branches
-# pylint: disable=too-many-locals
-@runner.command(
- context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@runner.command("run-local", commands.LocalCommand)
@click.pass_context
-@click.argument("method")
-@click.option("--mock/--no-mock", default=False, help="Mock the machines.")
-@click.option("--env", default=None, help="Specify a yaml file with machines.")
-@click.option(
- "--runtime", default=["runc"], help="The runtime to use.", multiple=True)
-@click.option("--metric", help="The metric to extract.", multiple=True)
-@click.option(
- "--runs", default=1, help="The number of times to run each benchmark.")
-@click.option(
- "--stat",
- default="median",
- help="How to aggregate the data from all runs."
- "\nmedian - returns the median of all runs (default)"
- "\nall - returns all results comma separated"
- "\nmeanstd - returns result as mean,std")
-# pylint: disable=too-many-statements
-def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
- metric: List[str], stat: str, **kwargs):
+def run_local(ctx, limit: float, **kwargs):
+ """Runs benchmarks locally."""
+ run(ctx, machine_producer.LocalMachineProducer(limit=limit), **kwargs)
+
+
+@runner.command("run-mock", commands.RunCommand)
+@click.pass_context
+def run_mock(ctx, **kwargs):
+ """Runs benchmarks on Mock machines. Used for testing."""
+ run(ctx, mock_producer.MockMachineProducer(), **kwargs)
+
+
+@runner.command("run-gcp", commands.GCPCommand)
+@click.pass_context
+def run_gcp(ctx, project: str, ssh_key_file: str, image: str,
+ image_project: str, machine_type: str, zone: str, ssh_user: str,
+ ssh_password: str, **kwargs):
+ """Runs all benchmarks on GCP instances."""
+
+ if not ssh_user:
+ ssh_user = harness.DEFAULT_USER
+
+ # Get the default project if one was not provided.
+ if not project:
+ sub = subprocess.run(
+ "gcloud config get-value project".split(" "), stdout=subprocess.PIPE)
+ if sub.returncode:
+ raise ValueError(
+ "Cannot get default project from gcloud. Is it configured>")
+ project = sub.stdout.decode("utf-8").strip("\n")
+
+ if not image_project:
+ image_project = project
+
+ # Check that the ssh-key exists and is readable.
+ if not os.access(ssh_key_file, os.R_OK):
+ raise ValueError(
+ "ssh key given `{ssh_key}` is does not exist or is not readable."
+ .format(ssh_key=ssh_key_file))
+
+ # Check that the image exists.
+ sub = subprocess.run(
+ "gcloud compute images describe {image} --project {image_project} --format=json"
+ .format(image=image, image_project=image_project).split(" "),
+ stdout=subprocess.PIPE)
+ if sub.returncode or "READY" not in json.loads(sub.stdout)["status"]:
+ raise ValueError(
+ "given image was not found or is not ready: {image} {image_project}."
+ .format(image=image, image_project=image_project))
+
+ # Check and set zone to default.
+ if not zone:
+ sub = subprocess.run(
+ "gcloud config get-value compute/zone".split(" "),
+ stdout=subprocess.PIPE)
+ if sub.returncode:
+ raise ValueError(
+ "Default zone is not set in gcloud. Set one or pass a zone with the --zone flag."
+ )
+ zone = sub.stdout.decode("utf-8").strip("\n")
+
+ producer = gcloud_producer.GCloudProducer(project, ssh_key_file, image,
+ image_project, machine_type, zone,
+ ssh_user, ssh_password)
+ run(ctx, producer, **kwargs)
+
+
+def run(ctx, producer: machine_producer.MachineProducer, method: str, runs: int,
+ runtime: List[str], metric: List[str], stat: str, **kwargs):
"""Runs arbitrary benchmarks.
All unknown command line flags are passed through to the underlying benchmark
@@ -139,16 +193,13 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
All benchmarks are run in parallel where possible, but have exclusive
ownership over the individual machines.
- Exactly one of the --mock and --env flag must be specified.
-
Every benchmark method will be run the times indicated by --runs.
Args:
ctx: Click context.
+ producer: A Machine Producer from which to get Machines.
method: A regular expression for methods to be run.
runs: Number of runs.
- env: Environment to use.
- mock: If true, use mocked environment (supercedes env).
runtime: A list of runtimes to test.
metric: A list of metrics to extract.
stat: The class of statistics to extract.
@@ -218,20 +269,6 @@ def run(ctx, method: str, runs: int, env: str, mock: bool, runtime: List[str],
sys.exit(1)
fold("method", list(methods.keys()), allow_flatten=True)
- # Construct the environment.
- if mock and env:
- # You can't provide both.
- logging.error("both --mock and --env are set: which one is it?")
- sys.exit(1)
- elif mock:
- producer = mock_producer.MockMachineProducer()
- elif env:
- producer = yaml_producer.YamlMachineProducer(env)
- else:
- # You must provide one of mock or env.
- logging.error("no enviroment provided: use --mock or --env.")
- sys.exit(1)
-
# Spin up the drivers.
#
# We ensure that metric is the last entry, because we have special behavior.
diff --git a/benchmarks/runner/commands.py b/benchmarks/runner/commands.py
new file mode 100644
index 000000000..7ab12fac6
--- /dev/null
+++ b/benchmarks/runner/commands.py
@@ -0,0 +1,135 @@
+# python3
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module with the guts of `click` commands.
+
+Overrides of click.core.Command. This is done so that flags are inherited
+between similar commands (the run commands). The classes below are meant to be
+used in click decorators like so:
+
+@runner.command("run-mock", RunCommand)
+def run_mock(**kwargs):
+ # mock implementation
+
+"""
+import click
+
+from benchmarks import harness
+
+
+class RunCommand(click.core.Command):
+ """Base Run Command with flags.
+
+ Attributes:
+ method: regex of which suite to choose (e.g. sysbench would run
+ sysbench.cpu, sysbench.memory, and sysbench.mutex). See the list command
+ for details.
+ metric: metric(s) to extract. See list command for details.
+ runtime: the runtime(s) on which to run.
+ runs: the number of runs to do of each method.
+ stat: how to aggregate results in the case of multiple runs (e.g. median).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ method = click.core.Argument(("method",))
+
+ metric = click.core.Option(("--metric",),
+ help="The metric to extract.",
+ multiple=True)
+
+ runtime = click.core.Option(("--runtime",),
+ default=["runc"],
+ help="The runtime to use.",
+ multiple=True)
+ runs = click.core.Option(("--runs",),
+ default=1,
+ help="The number of times to run each benchmark.")
+ stat = click.core.Option(
+ ("--stat",),
+ default="median",
+ help="How to aggregate the data from all runs."
+ "\nmedian - returns the median of all runs (default)"
+ "\nall - returns all results comma separated"
+ "\nmeanstd - returns result as mean,std")
+ self.params.extend([method, runtime, runs, stat, metric])
+ self.ignore_unknown_options = True
+ self.allow_extra_args = True
+
+
+class LocalCommand(RunCommand):
+ """LocalCommand inherits all flags from RunCommand.
+
+ Attributes:
+ limit: limits the number of machines on which to run benchmarks. For local
+ runs this caps how many benchmarks may run at a time, e.g. "startup"
+ requires one machine, so a limit of two allows two startup jobs at a
+ time. Defaults to 1.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.params.append(
+ click.core.Option(
+ ("--limit",),
+ default=1,
+ help="Limit of number of benchmarks that can run at a given time."))
+
+
+class GCPCommand(RunCommand):
+ """GCPCommand inherits all flags from RunCommand and adds flags for run_gcp method.
+
+ Attributes:
+ project: GCP project
+ ssh_key_path: path to the ssh-key to use for the run
+ image: name of the image to build machines from
+ image_project: GCP project under which to find image
+ zone: a GCP zone (e.g. us-west1-b)
+ ssh_user: username to use for the ssh-key
+ ssh_password: password to use for the ssh-key
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ project = click.core.Option(
+ ("--project",),
+ help="Project to run on if not default value given by 'gcloud config get-value project'."
+ )
+ ssh_key_path = click.core.Option(
+ ("--ssh-key-file",),
+ help="Path to a valid ssh private key to use. See README on generating a valid ssh key. Set to ~/.ssh/benchmark-tools by default.",
+ default=harness.DEFAULT_USER_HOME + "/.ssh/benchmark-tools")
+ image = click.core.Option(("--image",),
+ help="The image on which to build VMs.",
+ default="bm-tools-testing")
+ image_project = click.core.Option(
+ ("--image_project",),
+ help="The project under which the image to be used is listed.",
+ default="")
+ machine_type = click.core.Option(("--machine_type",),
+ help="Type to make all machines.",
+ default="n1-standard-4")
+ zone = click.core.Option(("--zone",),
+ help="The GCP zone to run on.",
+ default="")
+ ssh_user = click.core.Option(("--ssh-user",),
+ help="User for the ssh key.",
+ default=harness.DEFAULT_USER)
+ ssh_password = click.core.Option(("--ssh-password",),
+ help="Password for the ssh key.",
+ default="")
+ self.params.extend([
+ project, ssh_key_path, image, image_project, machine_type, zone,
+ ssh_user, ssh_password
+ ])
diff --git a/benchmarks/runner/runner_test.py b/benchmarks/runner/runner_test.py
index 5719c2838..7818d631a 100644
--- a/benchmarks/runner/runner_test.py
+++ b/benchmarks/runner/runner_test.py
@@ -49,7 +49,7 @@ def test_list():
def test_run():
cli_runner = testing.CliRunner()
- result = cli_runner.invoke(runner.runner, ["run", "--mock", "."])
+ result = cli_runner.invoke(runner.runner, ["run-mock", "."])
print(result.output)
assert result.exit_code == 0
diff --git a/benchmarks/suites/http.py b/benchmarks/suites/http.py
index ea9024e43..6efea938c 100644
--- a/benchmarks/suites/http.py
+++ b/benchmarks/suites/http.py
@@ -92,7 +92,7 @@ def http_app(server: machine.Machine,
redis = server.pull("redis")
image = server.pull(workload)
redis_port = 6379
- redis_name = "redis_server"
+ redis_name = "{workload}_redis_server".format(workload=workload)
with server.container(redis, name=redis_name).detach():
server.container(server_netcat, links={redis_name: redis_name})\
diff --git a/benchmarks/tcp/tcp_benchmark.sh b/benchmarks/tcp/tcp_benchmark.sh
index 69344c9c3..e65801a7b 100755
--- a/benchmarks/tcp/tcp_benchmark.sh
+++ b/benchmarks/tcp/tcp_benchmark.sh
@@ -41,6 +41,8 @@ duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses.
duration=30 # 30s is enough time to consistent results (experimentally).
helper_dir=$(dirname $0)
netstack_opts=
+disable_linux_gso=
+num_client_threads=1
# Check for netem support.
lsmod_output=$(lsmod | grep sch_netem)
@@ -125,6 +127,13 @@ while [ $# -gt 0 ]; do
shift
netstack_opts="${netstack_opts} -memprofile=$1"
;;
+ --disable-linux-gso)
+ disable_linux_gso=1
+ ;;
+ --num-client-threads)
+ shift
+ num_client_threads=$1
+ ;;
--helpers)
shift
[ "$#" -le 0 ] && echo "no helper dir provided" && exit 1
@@ -147,6 +156,8 @@ while [ $# -gt 0 ]; do
echo " --loss set the loss probability (%)"
echo " --duplicate set the duplicate probability (%)"
echo " --helpers set the helper directory"
+ echo " --num-client-threads number of parallel client threads to run"
+ echo " --disable-linux-gso disable segmentation offload in the Linux network stack"
echo ""
echo "The output will of the script will be:"
echo " <throughput> <client-cpu-usage> <server-cpu-usage>"
@@ -301,6 +312,14 @@ fi
# Add client and server addresses, and bring everything up.
${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0
${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0
+if [ "${disable_linux_gso}" == "1" ]; then
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off
+fi
${nsjoin_binary} /tmp/client.netns ip link set client.0 up
${nsjoin_binary} /tmp/client.netns ip link set lo up
${nsjoin_binary} /tmp/server.netns ip link set server.0 up
@@ -338,7 +357,7 @@ trap cleanup EXIT
# Run the benchmark, recording the results file.
while ${nsjoin_binary} /tmp/client.netns iperf \\
- -p ${proxy_port} -c ${client_addr} -t ${duration} -f m 2>&1 \\
+ -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\
| tee \$results_file \\
| grep "connect failed" >/dev/null; do
sleep 0.1 # Wait for all services.
diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go
index 361a56755..72ada5700 100644
--- a/benchmarks/tcp/tcp_proxy.go
+++ b/benchmarks/tcp/tcp_proxy.go
@@ -84,8 +84,8 @@ func (netImpl) printStats() {
}
const (
- nicID = 1 // Fixed.
- rcvBufSize = 1 << 20 // 1MB.
+ nicID = 1 // Fixed.
+ bufSize = 4 << 20 // 4MB.
)
type netstackImpl struct {
@@ -94,11 +94,11 @@ type netstackImpl struct {
mode string
}
-func setupNetwork(ifaceName string) (fd int, err error) {
+func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
// Get all interfaces in the namespace.
ifaces, err := net.Interfaces()
if err != nil {
- return -1, fmt.Errorf("querying interfaces: %v", err)
+ return nil, fmt.Errorf("querying interfaces: %v", err)
}
for _, iface := range ifaces {
@@ -107,39 +107,47 @@ func setupNetwork(ifaceName string) (fd int, err error) {
}
// Create the socket.
const protocol = 0x0300 // htons(ETH_P_ALL)
- fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
- if err != nil {
- return -1, fmt.Errorf("unable to create raw socket: %v", err)
- }
+ fds := make([]int, numChannels)
+ for i := range fds {
+ fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create raw socket: %v", err)
+ }
- // Bind to the appropriate device.
- ll := syscall.SockaddrLinklayer{
- Protocol: protocol,
- Ifindex: iface.Index,
- Pkttype: syscall.PACKET_HOST,
- }
- if err := syscall.Bind(fd, &ll); err != nil {
- return -1, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
- }
+ // Bind to the appropriate device.
+ ll := syscall.SockaddrLinklayer{
+ Protocol: protocol,
+ Ifindex: iface.Index,
+ Pkttype: syscall.PACKET_HOST,
+ }
+ if err := syscall.Bind(fd, &ll); err != nil {
+ return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
+ }
- // RAW Sockets by default have a very small SO_RCVBUF of 256KB,
- // up it to at least 1MB to reduce packet drops.
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, rcvBufSize); err != nil {
- return -1, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err)
- }
+ // RAW Sockets by default have a very small SO_RCVBUF of 256KB,
+ // up it to at least 4MB to reduce packet drops.
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
+ }
- if !*swgso && *gso != 0 {
- if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
- return -1, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if !*swgso && *gso != 0 {
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ }
}
+ fds[i] = fd
}
- return fd, nil
+ return fds, nil
}
- return -1, fmt.Errorf("failed to find interface: %v", ifaceName)
+ return nil, fmt.Errorf("failed to find interface: %v", ifaceName)
}
func newNetstackImpl(mode string) (impl, error) {
- fd, err := setupNetwork(*iface)
+ fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1))
if err != nil {
return nil, err
}
@@ -177,7 +185,7 @@ func newNetstackImpl(mode string) (impl, error) {
mac[0] &^= 0x1 // Clear multicast bit.
mac[0] |= 0x2 // Set local assignment bit (IEEE802).
ep, err := fdbased.New(&fdbased.Options{
- FDs: []int{fd},
+ FDs: fds,
MTU: uint32(*mtu),
EthernetHeader: true,
Address: tcpip.LinkAddress(mac),
diff --git a/benchmarks/workloads/BUILD b/benchmarks/workloads/BUILD
index 643806105..ccb86af5b 100644
--- a/benchmarks/workloads/BUILD
+++ b/benchmarks/workloads/BUILD
@@ -11,25 +11,25 @@ py_library(
filegroup(
name = "files",
srcs = [
- "//benchmarks/workloads/ab:files",
- "//benchmarks/workloads/absl:files",
- "//benchmarks/workloads/curl:files",
- "//benchmarks/workloads/ffmpeg:files",
- "//benchmarks/workloads/fio:files",
- "//benchmarks/workloads/httpd:files",
- "//benchmarks/workloads/iperf:files",
- "//benchmarks/workloads/netcat:files",
- "//benchmarks/workloads/nginx:files",
- "//benchmarks/workloads/node:files",
- "//benchmarks/workloads/node_template:files",
- "//benchmarks/workloads/redis:files",
- "//benchmarks/workloads/redisbenchmark:files",
- "//benchmarks/workloads/ruby:files",
- "//benchmarks/workloads/ruby_template:files",
- "//benchmarks/workloads/sleep:files",
- "//benchmarks/workloads/sysbench:files",
- "//benchmarks/workloads/syscall:files",
- "//benchmarks/workloads/tensorflow:files",
- "//benchmarks/workloads/true:files",
+ "//benchmarks/workloads/ab:tar",
+ "//benchmarks/workloads/absl:tar",
+ "//benchmarks/workloads/curl:tar",
+ "//benchmarks/workloads/ffmpeg:tar",
+ "//benchmarks/workloads/fio:tar",
+ "//benchmarks/workloads/httpd:tar",
+ "//benchmarks/workloads/iperf:tar",
+ "//benchmarks/workloads/netcat:tar",
+ "//benchmarks/workloads/nginx:tar",
+ "//benchmarks/workloads/node:tar",
+ "//benchmarks/workloads/node_template:tar",
+ "//benchmarks/workloads/redis:tar",
+ "//benchmarks/workloads/redisbenchmark:tar",
+ "//benchmarks/workloads/ruby:tar",
+ "//benchmarks/workloads/ruby_template:tar",
+ "//benchmarks/workloads/sleep:tar",
+ "//benchmarks/workloads/sysbench:tar",
+ "//benchmarks/workloads/syscall:tar",
+ "//benchmarks/workloads/tensorflow:tar",
+ "//benchmarks/workloads/true:tar",
],
)
diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD
index e99a8d674..4fc0ab735 100644
--- a/benchmarks/workloads/ab/BUILD
+++ b/benchmarks/workloads/ab/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD
index bb499620e..61e010096 100644
--- a/benchmarks/workloads/absl/BUILD
+++ b/benchmarks/workloads/absl/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/curl/BUILD
+++ b/benchmarks/workloads/curl/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD
index c1f2afc40..be472dfb2 100644
--- a/benchmarks/workloads/ffmpeg/BUILD
+++ b/benchmarks/workloads/ffmpeg/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -8,8 +10,8 @@ py_library(
srcs = ["__init__.py"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD
index 7fc96cfa5..de257adad 100644
--- a/benchmarks/workloads/fio/BUILD
+++ b/benchmarks/workloads/fio/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/httpd/BUILD
+++ b/benchmarks/workloads/httpd/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD
index fe0acbfce..8832a996c 100644
--- a/benchmarks/workloads/iperf/BUILD
+++ b/benchmarks/workloads/iperf/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/netcat/BUILD
+++ b/benchmarks/workloads/netcat/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/nginx/BUILD
+++ b/benchmarks/workloads/nginx/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD
index 59460d02f..71cd9f519 100644
--- a/benchmarks/workloads/node/BUILD
+++ b/benchmarks/workloads/node/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"index.js",
diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD
index ae7f121d3..ca996f068 100644
--- a/benchmarks/workloads/node_template/BUILD
+++ b/benchmarks/workloads/node_template/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"index.hbs",
diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/redis/BUILD
+++ b/benchmarks/workloads/redis/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD
index d40e75a3a..f5994a815 100644
--- a/benchmarks/workloads/redisbenchmark/BUILD
+++ b/benchmarks/workloads/redisbenchmark/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD
index 9846c7e70..e37d77804 100644
--- a/benchmarks/workloads/ruby/BUILD
+++ b/benchmarks/workloads/ruby/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -13,3 +15,14 @@ filegroup(
"index.rb",
],
)
+
+pkg_tar(
+ name = "tar",
+ srcs = [
+ "Dockerfile",
+ "Gemfile",
+ "Gemfile.lock",
+ "config.ru",
+ "index.rb",
+ ],
+)
diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD
index 2b99892af..27f7c0c46 100644
--- a/benchmarks/workloads/ruby_template/BUILD
+++ b/benchmarks/workloads/ruby_template/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"Gemfile",
@@ -13,4 +15,5 @@ filegroup(
"index.erb",
"main.rb",
],
+ strip_prefix = "third_party/gvisor/benchmarks/workloads/ruby_template",
)
diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD
index 83f3c71a0..eb0fb6165 100644
--- a/benchmarks/workloads/sleep/BUILD
+++ b/benchmarks/workloads/sleep/BUILD
@@ -1,10 +1,12 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD
index 35f4d460b..fd2f8f03d 100644
--- a/benchmarks/workloads/sysbench/BUILD
+++ b/benchmarks/workloads/sysbench/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD
index e1ff3059b..5100cbb21 100644
--- a/benchmarks/workloads/syscall/BUILD
+++ b/benchmarks/workloads/syscall/BUILD
@@ -1,4 +1,5 @@
load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement")
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -27,8 +28,8 @@ py_test(
],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
"syscall.c",
diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD
index 17f1f8ebb..026c3b316 100644
--- a/benchmarks/workloads/tensorflow/BUILD
+++ b/benchmarks/workloads/tensorflow/BUILD
@@ -1,3 +1,5 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
@@ -8,8 +10,8 @@ py_library(
srcs = ["__init__.py"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD
index 83f3c71a0..221c4b9a7 100644
--- a/benchmarks/workloads/true/BUILD
+++ b/benchmarks/workloads/true/BUILD
@@ -1,11 +1,14 @@
+load("@rules_pkg//:pkg.bzl", "pkg_tar")
+
package(
default_visibility = ["//benchmarks:__subpackages__"],
licenses = ["notice"],
)
-filegroup(
- name = "files",
+pkg_tar(
+ name = "tar",
srcs = [
"Dockerfile",
],
+ extension = "tar",
)
diff --git a/kokoro/issue_reviver.cfg b/kokoro/issue_reviver.cfg
new file mode 100644
index 000000000..2370d9250
--- /dev/null
+++ b/kokoro/issue_reviver.cfg
@@ -0,0 +1,15 @@
+build_file: "repo/scripts/issue_reviver.sh"
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 269ba5567..33fcc6c95 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -42,6 +42,15 @@ const (
NF_RETURN = -NF_REPEAT - 1
)
+// VerdictStrings maps int verdicts to the strings they represent. It is used
+// for debugging.
+var VerdictStrings = map[int32]string{
+ -NF_DROP - 1: "DROP",
+ -NF_ACCEPT - 1: "ACCEPT",
+ -NF_QUEUE - 1: "QUEUE",
+ NF_RETURN: "RETURN",
+}
+
// Socket options. These correspond to values in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
const (
@@ -179,7 +188,7 @@ const SizeOfXTCounters = 16
// the user data.
type XTEntryMatch struct {
MatchSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryMatch to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -199,7 +208,7 @@ const SizeOfXTEntryMatch = 32
// the user data.
type XTEntryTarget struct {
TargetSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryTarget to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -226,9 +235,9 @@ const SizeOfXTStandardTarget = 40
// ErrorName. It corresponds to struct xt_error_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTErrorTarget struct {
- Target XTEntryTarget
- ErrorName [XT_FUNCTION_MAXNAMELEN]byte
- _ [2]byte
+ Target XTEntryTarget
+ Name ErrorName
+ _ [2]byte
}
// SizeOfXTErrorTarget is the size of an XTErrorTarget.
@@ -237,7 +246,7 @@ const SizeOfXTErrorTarget = 64
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetinfo struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
HookEntry [NF_INET_NUMHOOKS]uint32
Underflow [NF_INET_NUMHOOKS]uint32
@@ -248,16 +257,11 @@ type IPTGetinfo struct {
// SizeOfIPTGetinfo is the size of an IPTGetinfo.
const SizeOfIPTGetinfo = 84
-// TableName returns the table name.
-func (info *IPTGetinfo) TableName() string {
- return tableName(info.Name[:])
-}
-
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
Size uint32
_ [4]byte
// Entrytable is omitted here because it would cause IPTGetEntries to
@@ -266,34 +270,22 @@ type IPTGetEntries struct {
// Entrytable [0]IPTEntry
}
-// TableName returns the entries' table name.
-func (entries *IPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
-}
-
// SizeOfIPTGetEntries is the size of an IPTGetEntries.
const SizeOfIPTGetEntries = 40
-// KernelIPTGetEntries is identical to IPTEntry, but includes the Elems field.
-// This struct marshaled via the binary package to write an KernelIPTGetEntries
-// to userspace.
+// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
+// Entrytable field. This struct is marshaled via the binary package to write
+// a KernelIPTGetEntries to userspace.
type KernelIPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
- Size uint32
- _ [4]byte
+ IPTGetEntries
Entrytable []KernelIPTEntry
}
-// TableName returns the entries' table name.
-func (entries *KernelIPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
-}
-
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTReplace struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
NumEntries uint32
Size uint32
@@ -306,14 +298,45 @@ type IPTReplace struct {
// Entries [0]IPTEntry
}
+// KernelIPTReplace is identical to IPTReplace, but includes the Entries field.
+type KernelIPTReplace struct {
+ IPTReplace
+ Entries [0]IPTEntry
+}
+
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
-func tableName(name []byte) string {
- for i, c := range name {
+// ExtensionName holds the name of a netfilter extension.
+type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ExtensionName) String() string {
+ return goString(en[:])
+}
+
+// TableName holds the name of a netfilter table.
+type TableName [XT_TABLE_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (tn TableName) String() string {
+ return goString(tn[:])
+}
+
+// ErrorName holds the name of a netfilter error. These can also hold
+// user-defined chains.
+type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ErrorName) String() string {
+ return goString(en[:])
+}
+
+func goString(cstring []byte) string {
+ for i, c := range cstring {
if c == 0 {
- return string(name[:i])
+ return string(cstring[:i])
}
}
- return string(name)
+ return string(cstring)
}
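For reference, a minimal standalone sketch (not part of the change above) of how the new fixed-size name types behave: String() stops at the first NUL byte, so a kernel-style padded name prints cleanly. The 32-byte array below merely stands in for TableName ([XT_TABLE_MAXNAMELEN]byte).

package main

import "fmt"

// goStringDemo mirrors the goString helper added above: it treats a
// fixed-size byte array as a NUL-terminated C string.
func goStringDemo(cstring []byte) string {
	for i, c := range cstring {
		if c == 0 {
			return string(cstring[:i])
		}
	}
	return string(cstring)
}

func main() {
	var name [32]byte // stand-in for a TableName-style fixed-size array
	copy(name[:], "nat")
	fmt.Println(goStringDemo(name[:])) // prints "nat", not the padded array
}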
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index 546668bca..5c5a58cd4 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -234,6 +234,19 @@ type StatxTimestamp struct {
_ int32
}
+// ToNsec returns the nanosecond representation.
+func (sxts StatxTimestamp) ToNsec() int64 {
+ return int64(sxts.Sec)*1e9 + int64(sxts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (sxts StatxTimestamp) ToNsecCapped() int64 {
+ if sxts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return sxts.ToNsec()
+}
+
// NsecToStatxTimestamp translates nanoseconds to StatxTimestamp.
func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
return StatxTimestamp{
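As a quick sanity check of the arithmetic above (a standalone sketch, not gVisor code): a timestamp of 2 seconds and 500000000 nanoseconds converts to 2500000000 nanoseconds.

package main

import "fmt"

func main() {
	// Mirrors StatxTimestamp.ToNsec: whole seconds scaled to nanoseconds
	// plus the fractional nanosecond part.
	sec, nsec := int64(2), int64(500000000)
	fmt.Println(sec*1e9 + nsec) // 2500000000
}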
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 6bc486b62..d99e37b40 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "small",
srcs = ["amutex_test.go"],
embed = [":amutex"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go
index 1d7f45641..8a3952f2a 100644
--- a/pkg/amutex/amutex_test.go
+++ b/pkg/amutex/amutex_test.go
@@ -15,9 +15,10 @@
package amutex
import (
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
type sleeper struct {
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 36beaade9..6403c60c2 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -20,4 +20,5 @@ go_test(
size = "small",
srcs = ["atomic_bitops_test.go"],
embed = [":atomicbitops"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go
index 965e9be79..9466d3e23 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomic_bitops_test.go
@@ -16,8 +16,9 @@ package atomicbitops
import (
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const iterations = 100
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index a0b21d4bd..2bb581b18 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -8,7 +8,10 @@ go_library(
srcs = ["compressio.go"],
importpath = "gvisor.dev/gvisor/pkg/compressio",
visibility = ["//:sandbox"],
- deps = ["//pkg/binary"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 3b0bb086e..5f52cbe74 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -52,9 +52,9 @@ import (
"hash"
"io"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
)
var bufPool = sync.Pool{
diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD
index 21adf3adf..adbd1e3f8 100644
--- a/pkg/control/server/BUILD
+++ b/pkg/control/server/BUILD
@@ -9,6 +9,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
],
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
index a56152d10..41abe1f2d 100644
--- a/pkg/control/server/server.go
+++ b/pkg/control/server/server.go
@@ -22,9 +22,9 @@ package server
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
)
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index d37047368..cf50ee53f 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -657,30 +657,28 @@ func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
return strings.Join(s, " ")
}
-// CPUInfo is to generate a section of one cpu in /proc/cpuinfo. This is a
-// minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
+// WriteCPUInfoTo generates a section for one CPU in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo; it is missing some fields like "microcode" that are
// not always printed in Linux. The bogomips field is simply made up.
-func (fs FeatureSet) CPUInfo(cpu uint) string {
- var b bytes.Buffer
- fmt.Fprintf(&b, "processor\t: %d\n", cpu)
- fmt.Fprintf(&b, "vendor_id\t: %s\n", fs.VendorID)
- fmt.Fprintf(&b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
- fmt.Fprintf(&b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
- fmt.Fprintf(&b, "model name\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "stepping\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
- fmt.Fprintln(&b, "fpu\t\t: yes")
- fmt.Fprintln(&b, "fpu_exception\t: yes")
- fmt.Fprintf(&b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
- fmt.Fprintln(&b, "wp\t\t: yes")
- fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
- fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
- fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
- fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
- return b.String()
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(b, "fpu\t\t: yes")
+ fmt.Fprintln(b, "fpu_exception\t: yes")
+ fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(b, "wp\t\t: yes")
+ fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
}
const (
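The refactor above switches from returning a string per CPU to appending into a caller-owned buffer. A minimal sketch of the intended call pattern (types simplified; the real method hangs off FeatureSet):

package main

import (
	"bytes"
	"fmt"
)

// writeCPU stands in for FeatureSet.WriteCPUInfoTo: it appends one CPU's
// /proc/cpuinfo section to a shared buffer instead of building a new string.
func writeCPU(cpu uint, b *bytes.Buffer) {
	fmt.Fprintf(b, "processor\t: %d\n", cpu)
	fmt.Fprintln(b, "") // sections are separated by a blank line
}

func main() {
	var b bytes.Buffer
	for cpu := uint(0); cpu < 4; cpu++ {
		writeCPU(cpu, &b) // one buffer for all CPUs, no per-CPU concatenation
	}
	fmt.Print(b.String())
}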
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 0b4b7cc44..9d68682c7 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
":eventchannel_go_proto",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
@@ -40,6 +41,7 @@ go_test(
srcs = ["event_test.go"],
embed = [":eventchannel"],
deps = [
+ "//pkg/sync",
"@com_github_golang_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index d37ad0428..9a29c58bd 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -22,13 +22,13 @@ package eventchannel
import (
"encoding/binary"
"fmt"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 3649097d6..7f41b4a27 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -16,11 +16,11 @@ package eventchannel
import (
"fmt"
- "sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
// testEmitter is an emitter that can be used in tests. It records all events
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index 56495cbd9..b0478c672 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "small",
srcs = ["fdchannel_test.go"],
embed = [":fdchannel"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go
index 5d01dc636..7a8a63a59 100644
--- a/pkg/fdchannel/fdchannel_test.go
+++ b/pkg/fdchannel/fdchannel_test.go
@@ -17,10 +17,11 @@ package fdchannel
import (
"io/ioutil"
"os"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSendRecvFD(t *testing.T) {
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
index aca2d8a82..91a202a30 100644
--- a/pkg/fdnotifier/BUILD
+++ b/pkg/fdnotifier/BUILD
@@ -11,6 +11,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/fdnotifier",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/sync",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
index f4aae1953..a6b63c982 100644
--- a/pkg/fdnotifier/fdnotifier.go
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -22,10 +22,10 @@ package fdnotifier
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index e590a71ba..85bd83af1 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -19,7 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
- "//pkg/syncutil",
+ "//pkg/sync",
],
)
@@ -31,4 +31,5 @@ go_test(
"flipcall_test.go",
],
embed = [":flipcall"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index 8d88b845d..2e28a149a 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -17,7 +17,8 @@ package flipcall
import (
"bytes"
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func Example() {
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 168a487ec..33fd55a44 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -16,9 +16,10 @@ package flipcall
import (
"runtime"
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
var testPacketWindowSize = pageSize
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 27b8939fc..ac974b232 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -18,7 +18,7 @@ import (
"reflect"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Packets consist of a 16-byte header followed by an arbitrarily-sized
@@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte {
var ioSync int64
func raceBecomeActive() {
- if syncutil.RaceEnabled {
- syncutil.RaceAcquire((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceAcquire((unsafe.Pointer)(&ioSync))
}
}
func raceBecomeInactive() {
- if syncutil.RaceEnabled {
- syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
}
}
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index 4b9321711..f22bd070d 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -19,5 +19,6 @@ go_test(
],
deps = [
":gate",
+ "//pkg/sync",
],
)
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
index 5dbd8d712..850693df8 100644
--- a/pkg/gate/gate_test.go
+++ b/pkg/gate/gate_test.go
@@ -15,11 +15,11 @@
package gate_test
import (
- "sync"
"testing"
"time"
"gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestBasicEnter(t *testing.T) {
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
new file mode 100644
index 000000000..5d31e5366
--- /dev/null
+++ b/pkg/goid/BUILD
@@ -0,0 +1,26 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "goid",
+ srcs = [
+ "goid.go",
+ "goid_amd64.s",
+ "goid_race.go",
+ "goid_unsafe.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/goid",
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "goid_test",
+ size = "small",
+ srcs = [
+ "empty_test.go",
+ "goid_test.go",
+ ],
+ embed = [":goid"],
+)
diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/goid/empty_test.go
index 5d4fd4dac..c0a4b17ab 100644
--- a/pkg/sentry/socket/rpcinet/rpcinet.go
+++ b/pkg/goid/empty_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package rpcinet implements sockets using an RPC for each syscall.
-package rpcinet
+// +build !race
+
+package goid
+
+import "testing"
+
+// TestNothing exists to make the build system happy.
+func TestNothing(t *testing.T) {}
diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go
new file mode 100644
index 000000000..39df30031
--- /dev/null
+++ b/pkg/goid/goid.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ panic("unimplemented for non-race builds")
+}
diff --git a/pkg/sentry/socket/rpcinet/device.go b/pkg/goid/goid_amd64.s
index 8cfd5f6e5..d9f5cd2a3 100644
--- a/pkg/sentry/socket/rpcinet/device.go
+++ b/pkg/goid/goid_amd64.s
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rpcinet
+#include "textflag.h"
-import "gvisor.dev/gvisor/pkg/sentry/device"
-
-var socketDevice = device.NewAnonDevice()
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVQ (TLS), R14
+ MOVQ R14, ret+0(FP)
+ RET
diff --git a/pkg/goid/goid_race.go b/pkg/goid/goid_race.go
new file mode 100644
index 000000000..1766beaee
--- /dev/null
+++ b/pkg/goid/goid_race.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Only available in race/gotsan builds.
+// +build race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ return goid()
+}
diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go
new file mode 100644
index 000000000..31970ce79
--- /dev/null
+++ b/pkg/goid/goid_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package goid
+
+import (
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestInitialGoID(t *testing.T) {
+ const max = 10000
+ if id := goid(); id < 0 || id > max {
+ t.Errorf("got goid = %d, want 0 < goid <= %d", id, max)
+ }
+}
+
+// TestGoIDSequence verifies that goid returns values which could plausibly be
+// goroutine IDs. If this test breaks or becomes flaky, the structs in
+// goid_unsafe.go may need to be updated.
+func TestGoIDSequence(t *testing.T) {
+ // Goroutine IDs are cached by each P.
+ runtime.GOMAXPROCS(1)
+
+ // Fill any holes in lower range.
+ for i := 0; i < 50; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ }
+
+ id := goid()
+ for i := 0; i < 100; i++ {
+ var (
+ newID int64
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ newID = goid()
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ if max := id + 100; newID <= id || newID > max {
+ t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id)
+ }
+ id = newID
+ }
+}
diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go
new file mode 100644
index 000000000..ded8004dd
--- /dev/null
+++ b/pkg/goid/goid_unsafe.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package goid
+
+// Structs from Go runtime. These may change in the future and require
+// updating. These structs are currently the same on both AMD64 and ARM64,
+// but may diverge in the future.
+
+type stack struct {
+ lo uintptr
+ hi uintptr
+}
+
+type gobuf struct {
+ sp uintptr
+ pc uintptr
+ g uintptr
+ ctxt uintptr
+ ret uint64
+ lr uintptr
+ bp uintptr
+}
+
+type g struct {
+ stack stack
+ stackguard0 uintptr
+ stackguard1 uintptr
+
+ _panic uintptr
+ _defer uintptr
+ m uintptr
+ sched gobuf
+ syscallsp uintptr
+ syscallpc uintptr
+ stktopsp uintptr
+ param uintptr
+ atomicstatus uint32
+ stackLock uint32
+ goid int64
+
+ // More fields...
+ //
+ // We only use goid; the fields before it are listed only to
+ // calculate the correct offset.
+}
+
+func getg() *g
+
+// goid returns the ID of the current goroutine.
+func goid() int64 {
+ return getg().goid
+}
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index a5d980d14..bcde6d308 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -8,6 +8,7 @@ go_library(
srcs = ["linewriter.go"],
importpath = "gvisor.dev/gvisor/pkg/linewriter",
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go
index cd6e4e2ce..a1b1285d4 100644
--- a/pkg/linewriter/linewriter.go
+++ b/pkg/linewriter/linewriter.go
@@ -17,7 +17,8 @@ package linewriter
import (
"bytes"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Writer is an io.Writer which buffers input, flushing
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index fc5f5779b..0df0f2849 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -16,7 +16,10 @@ go_library(
visibility = [
"//visibility:public",
],
- deps = ["//pkg/linewriter"],
+ deps = [
+ "//pkg/linewriter",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 9387586e6..91a81b288 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -25,12 +25,12 @@ import (
stdlog "log"
"os"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/linewriter"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Level is the log level.
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index dd6ca6d39..9145f3233 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -14,6 +14,7 @@ go_library(
":metric_go_proto",
"//pkg/eventchannel",
"//pkg/log",
+ "//pkg/sync",
],
)
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index eadde06e4..93d4f2b8c 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -18,12 +18,12 @@ package metric
import (
"errors"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index f32244c69..a3e05c96d 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 221516c6c..4045e41fa 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -17,12 +17,12 @@ package p9
import (
"errors"
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index de9357389..0254e4ccc 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -165,6 +165,35 @@ func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
}
+// GetXattr implements File.GetXattr.
+func (c *clientFile) GetXattr(name string, size uint64) (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return "", syscall.EOPNOTSUPP
+ }
+
+ rgetxattr := Rgetxattr{}
+ if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil {
+ return "", err
+ }
+
+ return rgetxattr.Value, nil
+}
+
+// SetXattr implements File.SetXattr.
+func (c *clientFile) SetXattr(name, value string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
+}
+
// Allocate implements File.Allocate.
func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
if atomic.LoadUint32(&c.closed) != 0 {
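A hedged sketch of driving the two new client methods from calling code; the p9.File value and the xattr name are illustrative only, and when the negotiated version predates 10 both calls fail with EOPNOTSUPP as shown above.

// Assumes: import "gvisor.dev/gvisor/pkg/p9"
func copyUserXattr(file p9.File) error {
    // Hint at a 256-byte buffer; larger values may still be returned,
    // per the File.GetXattr contract.
    val, err := file.GetXattr("user.example", 256)
    if err != nil {
        return err
    }
    return file.SetXattr("user.example", val, 0 /* flags */)
}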
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 96d1f2a8e..4607cfcdf 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -89,6 +89,22 @@ type File interface {
// On the server, SetAttr has a write concurrency guarantee.
SetAttr(valid SetAttrMask, attr SetAttr) error
+ // GetXattr returns extended attributes of this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ GetXattr(name string, size uint64) (string, error)
+
+ // SetXattr sets extended attributes on this node.
+ //
+ // TODO(b/127675828): Determine concurrency guarantees once implemented.
+ SetXattr(name, value string, flags uint32) error
+
// Allocate allows the caller to directly manipulate the allocated disk space
// for the file. See fallocate(2) for more details.
Allocate(mode AllocateMode, offset, length uint64) error
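On the server side the size argument is only a hint; a hedged sketch of an implementation that opts into the stricter ERANGE behavior described above (the map-backed store is purely illustrative).

// Assumes: import "syscall"
func getXattrFromMap(xattrs map[string]string, name string, size uint64) (string, error) {
    val, ok := xattrs[name]
    if !ok {
        return "", syscall.ENODATA
    }
    if uint64(len(val)) > size {
        // The value does not fit in the caller's buffer hint.
        return "", syscall.ERANGE
    }
    return val, nil
}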
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index b9582c07f..7d6653a07 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -913,6 +913,35 @@ func (t *Txattrcreate) handle(cs *connState) message {
}
// handle implements handler.handle.
+func (t *Tgetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ val, err := ref.file.GetXattr(t.Name, t.Size)
+ if err != nil {
+ return newErr(err)
+ }
+ return &Rgetxattr{Value: val}
+}
+
+// handle implements handler.handle.
+func (t *Tsetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.file.SetXattr(t.Name, t.Value, t.Flags); err != nil {
+ return newErr(err)
+ }
+ return &Rsetxattr{}
+}
+
+// handle implements handler.handle.
func (t *Treaddir) handle(cs *connState) message {
ref, ok := cs.LookupFID(t.Directory)
if !ok {
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index ffdd7e8c6..ceb723d86 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -1611,6 +1611,131 @@ func (r *Rxattrcreate) String() string {
return fmt.Sprintf("Rxattrcreate{}")
}
+// Tgetxattr is a getxattr request.
+type Tgetxattr struct {
+ // FID refers to the file for which to get xattrs.
+ FID FID
+
+ // Name is the xattr to get.
+ Name string
+
+ // Size is the buffer size for the xattr to get.
+ Size uint64
+}
+
+// Decode implements encoder.Decode.
+func (t *Tgetxattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Size = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tgetxattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tgetxattr) Type() MsgType {
+ return MsgTgetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetxattr) String() string {
+ return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size)
+}
+
+// Rgetxattr is a getxattr response.
+type Rgetxattr struct {
+ // Value is the extended attribute value.
+ Value string
+}
+
+// Decode implements encoder.Decode.
+func (r *Rgetxattr) Decode(b *buffer) {
+ r.Value = b.ReadString()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rgetxattr) Encode(b *buffer) {
+ b.WriteString(r.Value)
+}
+
+// Type implements message.Type.
+func (*Rgetxattr) Type() MsgType {
+ return MsgRgetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetxattr) String() string {
+ return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value)
+}
+
+// Tsetxattr sets extended attributes.
+type Tsetxattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Value is the attribute value.
+ Value string
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tsetxattr) Decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Value = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tsetxattr) Encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteString(t.Value)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tsetxattr) Type() MsgType {
+ return MsgTsetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetxattr) String() string {
+ return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags)
+}
+
+// Rsetxattr is a setxattr response.
+type Rsetxattr struct {
+}
+
+// Decode implements encoder.Decode.
+func (r *Rsetxattr) Decode(b *buffer) {
+}
+
+// Encode implements encoder.Encode.
+func (r *Rsetxattr) Encode(b *buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetxattr) Type() MsgType {
+ return MsgRsetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetxattr) String() string {
+ return fmt.Sprintf("Rsetxattr{}")
+}
+
// Treaddir is a readdir request.
type Treaddir struct {
// Directory is the directory FID to read.
@@ -2363,6 +2488,10 @@ func init() {
msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} })
+ msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
+ msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
+ msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
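For orientation, a standalone sketch of the byte layout Tgetxattr.Encode produces, written with encoding/binary rather than the package's internal buffer type; the little-endian uint16-length string encoding is the usual 9P convention and is an assumption here, not something this patch states.

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
)

func main() {
    // Serialize the body of Tgetxattr{FID: 1, Name: "user.example", Size: 256}
    // by hand, field by field.
    buf := new(bytes.Buffer)
    binary.Write(buf, binary.LittleEndian, uint32(1))                   // FID
    binary.Write(buf, binary.LittleEndian, uint16(len("user.example"))) // Name: length prefix...
    buf.WriteString("user.example")                                     // ...then raw bytes
    binary.Write(buf, binary.LittleEndian, uint64(256))                 // Size
    fmt.Printf("% x\n", buf.Bytes())
}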
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 6ba6a1654..825c939da 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -194,6 +194,21 @@ func TestEncodeDecode(t *testing.T) {
Flags: 3,
},
&Rxattrcreate{},
+ &Tgetxattr{
+ FID: 1,
+ Name: "abc",
+ Size: 2,
+ },
+ &Rgetxattr{
+ Value: "xyz",
+ },
+ &Tsetxattr{
+ FID: 1,
+ Name: "abc",
+ Value: "xyz",
+ Flags: 2,
+ },
+ &Rsetxattr{},
&Treaddir{
Directory: 1,
Offset: 2,
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index d3090535a..5ab00d625 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -339,6 +339,10 @@ const (
MsgRxattrwalk = 31
MsgTxattrcreate = 32
MsgRxattrcreate = 33
+ MsgTgetxattr = 34
+ MsgRgetxattr = 35
+ MsgTsetxattr = 36
+ MsgRsetxattr = 37
MsgTreaddir = 40
MsgRreaddir = 41
MsgTfsync = 50
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 28707c0ca..f4edd68b2 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -70,6 +70,7 @@ go_library(
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_mock//gomock:go_default_library",
],
@@ -83,6 +84,7 @@ go_test(
deps = [
"//pkg/fd",
"//pkg/p9",
+ "//pkg/sync",
"@com_github_golang_mock//gomock:go_default_library",
],
)
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 6e758148d..6e7bb3db2 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -22,7 +22,6 @@ import (
"os"
"reflect"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestPanic(t *testing.T) {
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 4d3271b37..dd8b01b6d 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -17,13 +17,13 @@ package p9test
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"testing"
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
index 865459411..72ef53313 100644
--- a/pkg/p9/path_tree.go
+++ b/pkg/p9/path_tree.go
@@ -16,7 +16,8 @@ package p9
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// pathNode is a single node in a path traversal.
diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go
index 52de889e1..2b14a5ce3 100644
--- a/pkg/p9/pool.go
+++ b/pkg/p9/pool.go
@@ -15,7 +15,7 @@
package p9
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// pool is a simple allocator.
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index 40b8fa023..fdfa83648 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -17,7 +17,6 @@ package p9
import (
"io"
"runtime/debug"
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fdchannel"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 6e8b4bbcd..9c11e28ce 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -19,11 +19,11 @@ import (
"fmt"
"io"
"io/ioutil"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index 36a694c58..34a15eb55 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 9
+ highestSupportedVersion uint32 = 10
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -161,3 +161,9 @@ func versionSupportsFlipcall(v uint32) bool {
func VersionSupportsOpenTruncateFlag(v uint32) bool {
return v >= 9
}
+
+// versionSupportsGetSetXattr returns true if version v supports
+// the Tgetxattr and Tsetxattr messages.
+func versionSupportsGetSetXattr(v uint32) bool {
+ return v >= 10
+}
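The bump to version 10 only matters during the handshake; a hedged sketch of the decrement-until-accepted loop the comment above describes (sendTversion is a hypothetical stand-in for the real Tversion exchange, passed in here so the sketch stays self-contained).

// Assumes: import "fmt"
func negotiate(sendTversion func(requested uint32) (accepted bool, err error)) (uint32, error) {
    for v := highestSupportedVersion; v >= lowestSupportedVersion; v-- {
        ok, err := sendTversion(v)
        if err != nil {
            return 0, err
        }
        if ok {
            return v, nil // Both ends now agree whether xattr messages are allowed.
        }
    }
    return 0, fmt.Errorf("no mutually supported 9P version")
}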
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 078f084b2..b506813f0 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -21,6 +21,7 @@ go_test(
"procid_test.go",
],
embed = [":procid"],
+ deps = ["//pkg/sync"],
)
go_test(
@@ -31,4 +32,5 @@ go_test(
"procid_test.go",
],
embed = [":procid"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go
index 88dd0b3ae..9ec08c3d6 100644
--- a/pkg/procid/procid_test.go
+++ b/pkg/procid/procid_test.go
@@ -17,9 +17,10 @@ package procid
import (
"os"
"runtime"
- "sync"
"syscall"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// runOnMain is used to send functions to run on the main (initial) thread.
diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD
index f4f2001f3..9d5b4859b 100644
--- a/pkg/rand/BUILD
+++ b/pkg/rand/BUILD
@@ -10,5 +10,8 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/rand",
visibility = ["//:sandbox"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
+ deps = [
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
index 2b92db3e6..0bdad5fad 100644
--- a/pkg/rand/rand_linux.go
+++ b/pkg/rand/rand_linux.go
@@ -19,9 +19,9 @@ package rand
import (
"crypto/rand"
"io"
- "sync"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
)
// reader implements an io.Reader that returns pseudorandom bytes.
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 7ad59dfd7..974d9af9b 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -27,6 +27,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
],
)
@@ -35,4 +36,5 @@ go_test(
size = "small",
srcs = ["refcounter_test.go"],
embed = [":refs"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index ad69e0757..c45ba8200 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -21,10 +21,10 @@ import (
"fmt"
"reflect"
"runtime"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RefCounter is the interface to be implemented by objects that are reference
diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go
index ffd3d3f07..1ab4a4440 100644
--- a/pkg/refs/refcounter_test.go
+++ b/pkg/refs/refcounter_test.go
@@ -16,8 +16,9 @@ package refs
import (
"reflect"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
type testCounter struct {
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 18c73cc24..65f22af2b 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -9,17 +9,23 @@ go_library(
srcs = [
"aligned.go",
"arch.go",
+ "arch_aarch64.go",
"arch_amd64.go",
"arch_amd64.s",
+ "arch_arm64.go",
+ "arch_state_aarch64.go",
"arch_state_x86.go",
"arch_x86.go",
"auxv.go",
+ "signal.go",
"signal_act.go",
"signal_amd64.go",
+ "signal_arm64.go",
"signal_info.go",
"signal_stack.go",
"stack.go",
"syscalls_amd64.go",
+ "syscalls_arm64.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/arch",
visibility = ["//:sandbox"],
@@ -32,6 +38,7 @@ go_library(
"//pkg/sentry/context",
"//pkg/sentry/limits",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
new file mode 100644
index 000000000..ea4dedbdf
--- /dev/null
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -0,0 +1,293 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "io"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // SyscallWidth is the width of a syscall instruction.
+ SyscallWidth = 4
+)
+
+// aarch64FPState is aarch64 floating point state.
+type aarch64FPState []byte
+
+// initAarch64FPState sets up the initial floating point state.
+func initAarch64FPState(data *FloatingPointData) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 32)[:4096]
+}
+
+// newAarch64FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by the host, even if the app won't use much of it due to a
+// restricted FeatureSet. Since the app may still be able to see state not
+// advertised by CPUID, we must ensure it does not contain any sentry state.
+func newAarch64FPState() aarch64FPState {
+ f := aarch64FPState(newAarch64FPStateSlice())
+ initAarch64FPState(f.FloatingPointData())
+ return f
+}
+
+// fork creates and returns an identical copy of the aarch64 floating point state.
+func (f aarch64FPState) fork() aarch64FPState {
+ n := aarch64FPState(newAarch64FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f aarch64FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newAarch64FPState()[0]))
+}
+
+// State contains the common architecture bits for aarch64 (the build tag of this
+// file ensures it's only built on aarch64).
+type State struct {
+ // The system registers.
+ Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
+
+ // Our floating point state.
+ aarch64FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.ARM64Registers{
+ R0: s.Regs.Regs[0],
+ R1: s.Regs.Regs[1],
+ R2: s.Regs.Regs[2],
+ R3: s.Regs.Regs[3],
+ R4: s.Regs.Regs[4],
+ R5: s.Regs.Regs[5],
+ R6: s.Regs.Regs[6],
+ R7: s.Regs.Regs[7],
+ R8: s.Regs.Regs[8],
+ R9: s.Regs.Regs[9],
+ R10: s.Regs.Regs[10],
+ R11: s.Regs.Regs[11],
+ R12: s.Regs.Regs[12],
+ R13: s.Regs.Regs[13],
+ R14: s.Regs.Regs[14],
+ R15: s.Regs.Regs[15],
+ R16: s.Regs.Regs[16],
+ R17: s.Regs.Regs[17],
+ R18: s.Regs.Regs[18],
+ R19: s.Regs.Regs[19],
+ R20: s.Regs.Regs[20],
+ R21: s.Regs.Regs[21],
+ R22: s.Regs.Regs[22],
+ R23: s.Regs.Regs[23],
+ R24: s.Regs.Regs[24],
+ R25: s.Regs.Regs[25],
+ R26: s.Regs.Regs[26],
+ R27: s.Regs.Regs[27],
+ R28: s.Regs.Regs[28],
+ R29: s.Regs.Regs[29],
+ R30: s.Regs.Regs[30],
+ Sp: s.Regs.Sp,
+ Pc: s.Regs.Pc,
+ Pstate: s.Regs.Pstate,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return State{
+ Regs: s.Regs,
+ FeatureSet: s.FeatureSet,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ // TODO(gvisor.dev/issue/1255): cpuid is not supported.
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return false
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// ClearSingleStep disables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R0": uintptr(s.Regs.Regs[0]),
+ "R1": uintptr(s.Regs.Regs[1]),
+ "R2": uintptr(s.Regs.Regs[2]),
+ "R3": uintptr(s.Regs.Regs[3]),
+ "R4": uintptr(s.Regs.Regs[4]),
+ "R5": uintptr(s.Regs.Regs[5]),
+ "R6": uintptr(s.Regs.Regs[6]),
+ "R7": uintptr(s.Regs.Regs[7]),
+ "R8": uintptr(s.Regs.Regs[8]),
+ "R9": uintptr(s.Regs.Regs[9]),
+ "R10": uintptr(s.Regs.Regs[10]),
+ "R11": uintptr(s.Regs.Regs[11]),
+ "R12": uintptr(s.Regs.Regs[12]),
+ "R13": uintptr(s.Regs.Regs[13]),
+ "R14": uintptr(s.Regs.Regs[14]),
+ "R15": uintptr(s.Regs.Regs[15]),
+ "R16": uintptr(s.Regs.Regs[16]),
+ "R17": uintptr(s.Regs.Regs[17]),
+ "R18": uintptr(s.Regs.Regs[18]),
+ "R19": uintptr(s.Regs.Regs[19]),
+ "R20": uintptr(s.Regs.Regs[20]),
+ "R21": uintptr(s.Regs.Regs[21]),
+ "R22": uintptr(s.Regs.Regs[22]),
+ "R23": uintptr(s.Regs.Regs[23]),
+ "R24": uintptr(s.Regs.Regs[24]),
+ "R25": uintptr(s.Regs.Regs[25]),
+ "R26": uintptr(s.Regs.Regs[26]),
+ "R27": uintptr(s.Regs.Regs[27]),
+ "R28": uintptr(s.Regs.Regs[28]),
+ "R29": uintptr(s.Regs.Regs[29]),
+ "R30": uintptr(s.Regs.Regs[30]),
+ "Sp": uintptr(s.Regs.Sp),
+ "Pc": uintptr(s.Regs.Pc),
+ "Pstate": uintptr(s.Regs.Pstate),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+}
+
+func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+ return s.Regs
+}
+
+var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs syscall.PtraceRegs
+ buf := make([]byte, ptraceRegsSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ s.Regs = regs
+ return ptraceRegsSize, nil
+}
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegsSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ return false
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case ARM64:
+ return &context64{
+ State{
+ FeatureSet: fs,
+ },
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
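A minimal in-package sketch of the PtraceGetRegs/PtraceSetRegs pair above: the registers are flattened with pkg/binary into a byte stream and restored into a fresh State (bytes and fmt imports assumed; illustration only).

func ptraceRegsRoundTripSketch() {
    var src State
    src.Regs.Regs[0] = 42
    src.Regs.Pc = 0xffff0000

    var buf bytes.Buffer
    if _, err := src.PtraceGetRegs(&buf); err != nil {
        panic(err)
    }

    var dst State
    if _, err := dst.PtraceSetRegs(&buf); err != nil {
        panic(err)
    }
    fmt.Println(dst.Regs.Regs[0] == 42, dst.Regs.Pc == src.Regs.Pc) // true true
}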
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
new file mode 100644
index 000000000..0d5b7d317
--- /dev/null
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -0,0 +1,266 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = ARM64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 48)
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux.
+ maxStackRand64 = 0x3ffff << 12 // ~1 GB
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 33) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 18) * usermem.PageSize
+)
+
+// context64 represents an ARM64 context.
+type context64 struct {
+ State
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return ARM64
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ }
+}
+
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Regs[0])
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Regs[0] = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Pc)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Pc = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Sp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Sp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // MRS_TPIDR_EL0
+ return 0
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // MSR_TPIDR_EL0
+ return false
+}
+
+// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
+func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+ c.Regs.Regs[3] = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoid eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ // TODO(gvisor.dev/issue/1239): Full ptrace support for Arm64.
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ // TODO(gvisor.dev/issue/1239): Full ptrace support for Arm64.
+ return nil
+}
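A hedged sketch of asking the logic above for a layout; it assumes the sentry's limits helpers (limits.NewLimitSet, SetUnchecked) behave as used elsewhere in the tree. An 8 MiB finite stack limit keeps the gap at minGap64 and the default direction top-down.

func mmapLayoutSketch() (MmapLayout, error) {
    l := limits.NewLimitSet()
    // 8 MiB stack: finite and well below minGap64.
    l.SetUnchecked(limits.Stack, limits.Limit{Cur: 8 << 20, Max: 8 << 20})

    c := &context64{}
    return c.NewMmapLayout(usermem.PageSize, maxAddr64, l)
}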
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/sentry/arch/arch_state_aarch64.go
index 8683cf677..0136a85ad 100644
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ b/pkg/sentry/arch/arch_state_aarch64.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,22 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package proc
+// +build arm64
-import "gvisor.dev/gvisor/pkg/sentry/kernel"
+package arch
-// TODO(gvisor.dev/issue/1195): Implement mountInfoFile and mountsFile.
+import (
+ "syscall"
+)
-// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo.
-//
-// +stateify savable
-type mountInfoFile struct {
- t *kernel.Task
+type syscallPtraceRegs struct {
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
}
-// mountsFile implements vfs.DynamicBytesSource for /proc/[pid]/mounts.
-//
-// +stateify savable
-type mountsFile struct {
- t *kernel.Task
+// saveRegs is invoked by stateify.
+func (s *State) saveRegs() syscallPtraceRegs {
+ return syscallPtraceRegs(s.Regs)
+}
+
+// loadRegs is invoked by stateify.
+func (s *State) loadRegs(r syscallPtraceRegs) {
+ s.Regs = syscall.PtraceRegs(r)
}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 9061fcc86..84f11b0d1 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64 i386
+
package arch
import (
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 9294ac773..9f41e566f 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -19,7 +19,6 @@ package arch
import (
"fmt"
"io"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/binary"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 9dc83e241..60c027aab 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -48,8 +48,45 @@ message AMD64Registers {
uint64 gs_base = 27;
}
+message ARM64Registers {
+ uint64 r0 = 1;
+ uint64 r1 = 2;
+ uint64 r2 = 3;
+ uint64 r3 = 4;
+ uint64 r4 = 5;
+ uint64 r5 = 6;
+ uint64 r6 = 7;
+ uint64 r7 = 8;
+ uint64 r8 = 9;
+ uint64 r9 = 10;
+ uint64 r10 = 11;
+ uint64 r11 = 12;
+ uint64 r12 = 13;
+ uint64 r13 = 14;
+ uint64 r14 = 15;
+ uint64 r15 = 16;
+ uint64 r16 = 17;
+ uint64 r17 = 18;
+ uint64 r18 = 19;
+ uint64 r19 = 20;
+ uint64 r20 = 21;
+ uint64 r21 = 22;
+ uint64 r22 = 23;
+ uint64 r23 = 24;
+ uint64 r24 = 25;
+ uint64 r25 = 26;
+ uint64 r26 = 27;
+ uint64 r27 = 28;
+ uint64 r28 = 29;
+ uint64 r29 = 30;
+ uint64 r30 = 31;
+ uint64 sp = 32;
+ uint64 pc = 33;
+ uint64 pstate = 34;
+}
message Registers {
oneof arch {
AMD64Registers amd64 = 1;
+ ARM64Registers arm64 = 2;
}
}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
new file mode 100644
index 000000000..402e46025
--- /dev/null
+++ b/pkg/sentry/arch/signal.go
@@ -0,0 +1,250 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// SignalAct represents the action that should be taken when a signal is
+// delivered, and is equivalent to struct sigaction.
+//
+// +stateify savable
+type SignalAct struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64 // Only used on amd64.
+ Mask linux.SignalSet
+}
+
+// SerializeFrom implements NativeSignalAct.SerializeFrom.
+func (s *SignalAct) SerializeFrom(other *SignalAct) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalAct.DeserializeTo.
+func (s *SignalAct) DeserializeTo(other *SignalAct) {
+ *other = *s
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// SerializeFrom implements NativeSignalStack.SerializeFrom.
+func (s *SignalStack) SerializeFrom(other *SignalStack) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalStack.DeserializeTo.
+func (s *SignalStack) DeserializeTo(other *SignalStack) {
+ *other = *s
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in the Linux kernel (include/uapi/asm-generic/siginfo.h).
+//
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// Pid returns the si_pid field.
+func (s *SignalInfo) Pid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPid mutates the si_pid field.
+func (s *SignalInfo) SetPid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Uid returns the si_uid field.
+func (s *SignalInfo) Uid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUid mutates the si_uid field.
+func (s *SignalInfo) SetUid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() linux.TimerID {
+ return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val linux.TimerID) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
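The accessors above are just typed views onto the 112-byte _sifields union; a brief sketch of composing a SIGCHLD siginfo the way Linux's _sigchld member is laid out (the CLD_EXITED value of 1 is taken from Linux headers and is an assumption here).

func sigchldInfoSketch(pid, uid, exitCode int32) *SignalInfo {
    info := &SignalInfo{
        Signo: int32(linux.SIGCHLD),
        Code:  1, // CLD_EXITED (assumed constant from Linux).
    }
    info.SetPid(pid)         // _sigchld._pid
    info.SetUid(uid)         // _sigchld._uid
    info.SetStatus(exitCode) // _sigchld._status
    return info
}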
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index febd6f9b9..1e4f9c3c2 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -26,236 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
-// SignalAct represents the action that should be taken when a signal is
-// delivered, and is equivalent to struct sigaction on 64-bit x86.
-//
-// +stateify savable
-type SignalAct struct {
- Handler uint64
- Flags uint64
- Restorer uint64
- Mask linux.SignalSet
-}
-
-// SerializeFrom implements NativeSignalAct.SerializeFrom.
-func (s *SignalAct) SerializeFrom(other *SignalAct) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalAct.DeserializeTo.
-func (s *SignalAct) DeserializeTo(other *SignalAct) {
- *other = *s
-}
-
-// SignalStack represents information about a user stack, and is equivalent to
-// stack_t on 64-bit x86.
-//
-// +stateify savable
-type SignalStack struct {
- Addr uint64
- Flags uint32
- _ uint32
- Size uint64
-}
-
-// SerializeFrom implements NativeSignalStack.SerializeFrom.
-func (s *SignalStack) SerializeFrom(other *SignalStack) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalStack.DeserializeTo.
-func (s *SignalStack) DeserializeTo(other *SignalStack) {
- *other = *s
-}
-
-// SignalInfo represents information about a signal being delivered, and is
-// equivalent to struct siginfo on 64-bit x86.
-//
-// +stateify savable
-type SignalInfo struct {
- Signo int32 // Signal number
- Errno int32 // Errno value
- Code int32 // Signal code
- _ uint32
-
- // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
- // are accessed through methods.
- //
- // For reference, here is the definition of _sifields: (_sigfault._trapno,
- // which does not exist on x86, omitted for clarity)
- //
- // union {
- // int _pad[SI_PAD_SIZE];
- //
- // /* kill() */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // } _kill;
- //
- // /* POSIX.1b timers */
- // struct {
- // __kernel_timer_t _tid; /* timer id */
- // int _overrun; /* overrun count */
- // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
- // sigval_t _sigval; /* same as below */
- // int _sys_private; /* not to be passed to user */
- // } _timer;
- //
- // /* POSIX.1b signals */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // sigval_t _sigval;
- // } _rt;
- //
- // /* SIGCHLD */
- // struct {
- // __kernel_pid_t _pid; /* which child */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // int _status; /* exit code */
- // __ARCH_SI_CLOCK_T _utime;
- // __ARCH_SI_CLOCK_T _stime;
- // } _sigchld;
- //
- // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
- // struct {
- // void *_addr; /* faulting insn/memory ref. */
- // short _addr_lsb; /* LSB of the reported address */
- // } _sigfault;
- //
- // /* SIGPOLL */
- // struct {
- // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
- // int _fd;
- // } _sigpoll;
- //
- // /* SIGSYS */
- // struct {
- // void *_call_addr; /* calling user insn */
- // int _syscall; /* triggering system call number */
- // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
- // } _sigsys;
- // } _sifields;
- //
- // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
- // bytes.
- Fields [128 - 16]byte
-}
-
-// FixSignalCodeForUser fixes up si_code.
-//
-// The si_code we get from Linux may contain the kernel-specific code in the
-// top 16 bits if it's positive (e.g., from ptrace). Linux's
-// copy_siginfo_to_user does
-// err |= __put_user((short)from->si_code, &to->si_code);
-// to mask out those bits and we need to do the same.
-func (s *SignalInfo) FixSignalCodeForUser() {
- if s.Code > 0 {
- s.Code &= 0x0000ffff
- }
-}
-
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
-func (s *SignalInfo) Sigval() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[8:16])
-}
-
-// SetSigval mutates the sigval field.
-func (s *SignalInfo) SetSigval(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
-}
-
-// TimerID returns the si_timerid field.
-func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetTimerID sets the si_timerid field.
-func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Overrun returns the si_overrun field.
-func (s *SignalInfo) Overrun() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetOverrun sets the si_overrun field.
-func (s *SignalInfo) SetOverrun(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Addr returns the si_addr field.
-func (s *SignalInfo) Addr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetAddr sets the si_addr field.
-func (s *SignalInfo) SetAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Status returns the si_status field.
-func (s *SignalInfo) Status() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetStatus mutates the si_status field.
-func (s *SignalInfo) SetStatus(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// CallAddr returns the si_call_addr field.
-func (s *SignalInfo) CallAddr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetCallAddr mutates the si_call_addr field.
-func (s *SignalInfo) SetCallAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Syscall returns the si_syscall field.
-func (s *SignalInfo) Syscall() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetSyscall mutates the si_syscall field.
-func (s *SignalInfo) SetSyscall(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// Arch returns the si_arch field.
-func (s *SignalInfo) Arch() uint32 {
- return usermem.ByteOrder.Uint32(s.Fields[12:16])
-}
-
-// SetArch mutates the si_arch field.
-func (s *SignalInfo) SetArch(val uint32) {
- usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
-}
-
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
type SignalContext64 struct {
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
new file mode 100644
index 000000000..7d0e98935
--- /dev/null
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -0,0 +1,126 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+)
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ FaultAddr uint64
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Reserved [4096]uint8
+}
+
+// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
+type UContext64 struct {
+ Flags uint64
+ Link *UContext64
+ Stack SignalStack
+ Sigset linux.SignalSet
+ // glibc uses a 1024-bit sigset_t
+ _pad [(1024 - 64) / 8]byte
+ // sigcontext must be 16-byte aligned
+ _pad2 [8]byte
+ // last for future expansion
+ MContext SignalContext64
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// SignalSetup implements Context.SignalSetup.
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ Flags: 0,
+ Stack: *alt,
+ MContext: SignalContext64{
+ Regs: c.Regs.Regs,
+ Sp: c.Regs.Sp,
+ Pc: c.Regs.Pc,
+ Pstate: c.Regs.Pstate,
+ },
+ Sigset: sigset,
+ }
+
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ panic("can't get size of UContext64")
+ }
+ // st.Arch.Width() is for the restorer address. sizeof(siginfo) == 128.
+ frameSize := int(st.Arch.Width()) + ucSize + 128
+ frameBottom := (sp-usermem.Addr(frameSize)) & ^usermem.Addr(15) - 8
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+
+ // Set up registers.
+ c.Regs.Sp = uint64(st.Bottom)
+ c.Regs.Pc = act.Handler
+ c.Regs.Regs[0] = uint64(info.Signo)
+ c.Regs.Regs[1] = uint64(infoAddr)
+ c.Regs.Regs[2] = uint64(ucAddr)
+
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore.
+// Only used on x86.
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ return 0, SignalStack{}, nil
+}
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 5a3228113..d324da705 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build i386 amd64 arm64
package arch
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
new file mode 100644
index 000000000..00d5ef461
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+const restartSyscallNr = uintptr(128)
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Regs[8])
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.Regs.Regs[0])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[4])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[5])},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+func (c *context64) RestartSyscall() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
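From the caller's side, the convention above means the syscall number is read from R8 and the first six arguments from R0..R5; a short sketch decoding an openat(2) call (56 is __NR_openat on arm64, treated here as an assumption rather than something this patch defines).

func decodeOpenatSketch(c *context64) (dirfd int32, pathAddr uintptr, flags uint, ok bool) {
    if c.SyscallNo() != 56 { // __NR_openat on arm64 (assumed).
        return 0, 0, 0, false
    }
    args := c.SyscallArgs()
    return int32(args[0].Value), args[1].Value, uint(args[2].Value), true
}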
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 5522cecd0..2561a6109 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/sentry/strace",
"//pkg/sentry/usage",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
],
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index e1f2fea60..151808911 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -19,10 +19,10 @@ import (
"runtime"
"runtime/pprof"
"runtime/trace"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
)
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 1098ed777..97fa1512c 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -8,7 +8,10 @@ go_library(
srcs = ["device.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/device",
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/abi/linux"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
index 47945d1a7..69e71e322 100644
--- a/pkg/sentry/device/device.go
+++ b/pkg/sentry/device/device.go
@@ -19,10 +19,10 @@ package device
import (
"bytes"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Registry tracks all simple devices and related state on the system for
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index c035ffff7..7d5d72d5a 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -68,7 +68,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
@@ -115,6 +115,7 @@ go_test(
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 9ac62c84d..e03e3e417 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -17,12 +17,13 @@ package fs
import (
"fmt"
"io"
- "sync"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -395,12 +396,12 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// Size and permissions are set on upper when the file content is copied
// and when the file is created respectively.
func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error {
- // Extract attributes fro the lower filesystem.
+ // Extract attributes from the lower filesystem.
lowerAttr, err := lower.UnstableAttr(ctx)
if err != nil {
return err
}
- lowerXattr, err := lower.Listxattr()
+ lowerXattr, err := lower.ListXattr(ctx)
if err != nil && err != syserror.EOPNOTSUPP {
return err
}
@@ -421,11 +422,11 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error
if isXattrOverlay(name) {
continue
}
- value, err := lower.Getxattr(name)
+ value, err := lower.GetXattr(ctx, name, linux.XATTR_SIZE_MAX)
if err != nil {
return err
}
- if err := upper.InodeOperations.Setxattr(upper, name, value); err != nil {
+ if err := upper.InodeOperations.SetXattr(ctx, upper, name, value, 0 /* flags */); err != nil {
return err
}
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index 1d80bf15a..738580c5f 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -19,13 +19,13 @@ import (
"crypto/rand"
"fmt"
"io"
- "sync"
"testing"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 3cb73bd78..31fc4d87b 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -18,7 +18,6 @@ import (
"fmt"
"path"
"sort"
- "sync"
"sync/atomic"
"syscall"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 60a15a275..25514ace4 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCache is an LRU cache of Dirents. The Dirent's refCount is
diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go
index ebb80bd50..525ee25f9 100644
--- a/pkg/sentry/fs/dirent_cache_limiter.go
+++ b/pkg/sentry/fs/dirent_cache_limiter.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCacheLimiter acts as a global limit for all dirent caches in the
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index 277ee4c31..cc43de69d 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -23,6 +23,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 669ffcb75..5b6cfeb0a 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -17,7 +17,6 @@ package fdpipe
import (
"os"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/fd"
@@ -29,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
index 29175fb3d..cee87f726 100644
--- a/pkg/sentry/fs/fdpipe/pipe_state.go
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -17,10 +17,10 @@ package fdpipe
import (
"fmt"
"io/ioutil"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// beforeSave is invoked by stateify.
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index a2f966cb6..7c4586296 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -16,7 +16,6 @@ package fs
import (
"math"
- "sync"
"sync/atomic"
"time"
@@ -29,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 225e40186..8991207b4 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -16,13 +16,13 @@ package fs
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -475,7 +475,7 @@ func readdirEntries(ctx context.Context, o *overlayEntry) (*SortedDentryMap, err
// Skip this name if it is a negative entry in the
// upper or there exists a whiteout for it.
if o.upper != nil {
- if overlayHasWhiteout(o.upper, name) {
+ if overlayHasWhiteout(ctx, o.upper, name) {
continue
}
}
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
index b157fd228..c5b51620a 100644
--- a/pkg/sentry/fs/filesystems.go
+++ b/pkg/sentry/fs/filesystems.go
@@ -18,9 +18,9 @@ import (
"fmt"
"sort"
"strings"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags.
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index 8b2a5e6b2..26abf49e2 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -54,10 +54,9 @@
package fs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 9ca695a95..945b6270d 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -93,6 +93,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index b06a71cc2..837fc70b5 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -16,7 +16,6 @@ package fsutil
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/log"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 30475f340..a625f0e26 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -16,7 +16,6 @@ package fsutil
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// HostMappable implements memmap.Mappable and platform.File over a
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 4e100a402..df7b74855 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -15,13 +15,13 @@
package fsutil
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -203,7 +203,7 @@ func (i *InodeSimpleAttributes) NotifyModificationAndStatusChange(ctx context.Co
}
// InodeSimpleExtendedAttributes implements
-// fs.InodeOperations.{Get,Set,List}xattr.
+// fs.InodeOperations.{Get,Set,List}Xattr.
//
// +stateify savable
type InodeSimpleExtendedAttributes struct {
@@ -212,8 +212,8 @@ type InodeSimpleExtendedAttributes struct {
xattrs map[string]string
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *InodeSimpleExtendedAttributes) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
i.mu.RLock()
value, ok := i.xattrs[name]
i.mu.RUnlock()
@@ -223,19 +223,31 @@ func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (stri
return value, nil
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (i *InodeSimpleExtendedAttributes) Setxattr(_ *fs.Inode, name, value string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *InodeSimpleExtendedAttributes) SetXattr(_ context.Context, _ *fs.Inode, name, value string, flags uint32) error {
i.mu.Lock()
+ defer i.mu.Unlock()
if i.xattrs == nil {
+ if flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
i.xattrs = make(map[string]string)
}
+
+ _, ok := i.xattrs[name]
+ if ok && flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
i.xattrs[name] = value
- i.mu.Unlock()
return nil
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (i *InodeSimpleExtendedAttributes) Listxattr(_ *fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
i.mu.RLock()
names := make(map[string]struct{}, len(i.xattrs))
for name := range i.xattrs {
@@ -437,18 +449,18 @@ func (InodeNotSymlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
// extended attributes.
type InodeNoExtendedAttributes struct{}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (InodeNoExtendedAttributes) Getxattr(*fs.Inode, string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (InodeNoExtendedAttributes) GetXattr(context.Context, *fs.Inode, string, uint64) (string, error) {
return "", syserror.EOPNOTSUPP
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (InodeNoExtendedAttributes) Setxattr(*fs.Inode, string, string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (InodeNoExtendedAttributes) SetXattr(context.Context, *fs.Inode, string, string, uint32) error {
return syserror.EOPNOTSUPP
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (InodeNoExtendedAttributes) Listxattr(*fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
return nil, syserror.EOPNOTSUPP
}
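
The flag handling added to SetXattr above follows setxattr(2): XATTR_CREATE fails with EEXIST when the attribute already exists, and XATTR_REPLACE fails with ENODATA when it does not; with no flags the value is created or replaced unconditionally. A self-contained sketch of those semantics over a plain map (the helper is illustrative only):

    package example

    import (
        "gvisor.dev/gvisor/pkg/abi/linux"
        "gvisor.dev/gvisor/pkg/syserror"
    )

    // setXattr applies the same create/replace rules to an in-memory store.
    func setXattr(store map[string]string, name, value string, flags uint32) error {
        _, exists := store[name]
        if exists && flags&linux.XATTR_CREATE != 0 {
            return syserror.EEXIST // creation demanded, but the attribute exists
        }
        if !exists && flags&linux.XATTR_REPLACE != 0 {
            return syserror.ENODATA // replacement demanded, but nothing to replace
        }
        store[name] = value
        return nil
    }
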
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 798920d18..20a014402 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -17,7 +17,6 @@ package fsutil
import (
"fmt"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -30,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Lock order (compare the lock order model in mm/mm.go):
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 4a005c605..fd870e8e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -44,6 +44,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
index 44b72582a..2125dafef 100644
--- a/pkg/sentry/fs/gofer/context_file.go
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -59,6 +59,20 @@ func (c *contextFile) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9
return err
}
+func (c *contextFile) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := c.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (c *contextFile) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := c.file.Allocate(mode, offset, length)
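
Like the other contextFile wrappers, getXattr and setXattr bracket the blocking 9P round trip with UninterruptibleSleepStart/Finish so the task is not interrupted mid-RPC. The general shape, with goferRPC standing in for any blocking p9 client call:

    package example

    import "gvisor.dev/gvisor/pkg/sentry/context"

    // callGofer marks the task uninterruptible for the duration of a blocking
    // gofer RPC and restores it afterwards.
    func callGofer(ctx context.Context, goferRPC func() error) error {
        ctx.UninterruptibleSleepStart(false)
        err := goferRPC()
        ctx.UninterruptibleSleepFinish(false)
        return err
    }
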
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 91263ebdc..98d1a8a48 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -16,7 +16,6 @@ package gofer
import (
"errors"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -38,8 +38,7 @@ import (
//
// +stateify savable
type inodeOperations struct {
- fsutil.InodeNotVirtual `state:"nosave"`
- fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
// fileState implements fs.CachedFileObject. It exists
// to break a circular load dependency between inodeOperations
@@ -604,6 +603,21 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
}
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *inodeOperations) GetXattr(ctx context.Context, inode *fs.Inode, name string, size uint64) (string, error) {
+ return i.fileState.file.getXattr(ctx, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *inodeOperations) SetXattr(ctx context.Context, inode *fs.Inode, name string, value string, flags uint32) error {
+ return i.fileState.file.setXattr(ctx, name, value, flags)
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *inodeOperations) ListXattr(context.Context, *fs.Inode) (map[string]struct{}, error) {
+ return nil, syscall.EOPNOTSUPP
+}
+
// Allocate implements fs.InodeOperations.Allocate.
func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
// This can only be called for files anyway.
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 4e358a46a..edc796ce0 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -16,7 +16,6 @@ package gofer
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refs"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 23daeb528..2b581aa69 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/sentry/unimpl",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index a6e4a09e3..873a1c52d 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -15,7 +15,6 @@
package host
import (
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 107336a3e..c076d5bdd 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -16,7 +16,6 @@ package host
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -30,6 +29,7 @@ import (
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 90331e3b2..753ef8cd6 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -15,8 +15,6 @@
package host
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index 91e2fde2f..e4cf5a570 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -15,8 +15,6 @@
package fs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
@@ -26,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -262,28 +261,28 @@ func (i *Inode) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
return i.InodeOperations.UnstableAttr(ctx, i)
}
-// Getxattr calls i.InodeOperations.Getxattr with i as the Inode.
-func (i *Inode) Getxattr(name string) (string, error) {
+// GetXattr calls i.InodeOperations.GetXattr with i as the Inode.
+func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
if i.overlay != nil {
- return overlayGetxattr(i.overlay, name)
+ return overlayGetXattr(ctx, i.overlay, name, size)
}
- return i.InodeOperations.Getxattr(i, name)
+ return i.InodeOperations.GetXattr(ctx, i, name, size)
}
-// Setxattr calls i.InodeOperations.Setxattr with i as the Inode.
-func (i *Inode) Setxattr(name, value string) error {
+// SetXattr calls i.InodeOperations.SetXattr with i as the Inode.
+func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error {
if i.overlay != nil {
- return overlaySetxattr(i.overlay, name, value)
+ return overlaySetxattr(ctx, i.overlay, d, name, value, flags)
}
- return i.InodeOperations.Setxattr(i, name, value)
+ return i.InodeOperations.SetXattr(ctx, i, name, value, flags)
}
-// Listxattr calls i.InodeOperations.Listxattr with i as the Inode.
-func (i *Inode) Listxattr() (map[string]struct{}, error) {
+// ListXattr calls i.InodeOperations.ListXattr with i as the Inode.
+func (i *Inode) ListXattr(ctx context.Context) (map[string]struct{}, error) {
if i.overlay != nil {
- return overlayListxattr(i.overlay)
+ return overlayListXattr(ctx, i.overlay)
}
- return i.InodeOperations.Listxattr(i)
+ return i.InodeOperations.ListXattr(ctx, i)
}
// CheckPermission will check if the caller may access this file in the
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
index 0f2a66a79..efd3c962b 100644
--- a/pkg/sentry/fs/inode_inotify.go
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watches is the collection of inotify watches on an inode.
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
index 5cde9d215..13261cb81 100644
--- a/pkg/sentry/fs/inode_operations.go
+++ b/pkg/sentry/fs/inode_operations.go
@@ -170,20 +170,27 @@ type InodeOperations interface {
// file system events.
UnstableAttr(ctx context.Context, inode *Inode) (UnstableAttr, error)
- // Getxattr retrieves the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP. Inodes that
- // support extended attributes but don't have a value at name return
+ // GetXattr retrieves the value of the extended attribute specified by name.
+ // Inodes that do not support extended attributes return EOPNOTSUPP. Inodes
+ // that support extended attributes but don't have a value at name return
// ENODATA.
- Getxattr(inode *Inode, name string) (string, error)
+ //
+ // If this is called through the getxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ GetXattr(ctx context.Context, inode *Inode, name string, size uint64) (string, error)
- // Setxattr sets the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP.
- Setxattr(inode *Inode, name, value string) error
+ // SetXattr sets the value of the extended attribute specified by name. Inodes
+ // that do not support extended attributes return EOPNOTSUPP.
+ SetXattr(ctx context.Context, inode *Inode, name, value string, flags uint32) error
- // Listxattr returns the set of all extended attributes names that
+ // ListXattr returns the set of all extended attribute names that
// have values. Inodes that do not support extended attributes return
// EOPNOTSUPP.
- Listxattr(inode *Inode) (map[string]struct{}, error)
+ ListXattr(ctx context.Context, inode *Inode) (map[string]struct{}, error)
// Check determines whether an Inode can be accessed with the
// requested permission mask using the context (which gives access
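
The size argument documented above is only a hint taken from the getxattr(2) buffer size; the syscall layer performs the authoritative length check, so implementations may ignore it. A hedged sketch of an implementation that chooses to honor the hint, using a map-backed store for illustration:

    package example

    import "syscall"

    // getXattr refuses to return a value that would not fit in the caller's
    // buffer; returning the full value regardless of size would also satisfy
    // the contract above.
    func getXattr(store map[string]string, name string, size uint64) (string, error) {
        value, ok := store[name]
        if !ok {
            return "", syscall.ENODATA
        }
        if size != 0 && uint64(len(value)) > size {
            return "", syscall.ERANGE
        }
        return value, nil
    }
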
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 13d11e001..c477de837 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -25,13 +25,13 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-func overlayHasWhiteout(parent *Inode, name string) bool {
- s, err := parent.Getxattr(XattrOverlayWhiteout(name))
+func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool {
+ s, err := parent.GetXattr(ctx, XattrOverlayWhiteout(name), 1)
return err == nil && s == "y"
}
-func overlayCreateWhiteout(parent *Inode, name string) error {
- return parent.InodeOperations.Setxattr(parent, XattrOverlayWhiteout(name), "y")
+func overlayCreateWhiteout(ctx context.Context, parent *Inode, name string) error {
+ return parent.InodeOperations.SetXattr(ctx, parent, XattrOverlayWhiteout(name), "y", 0 /* flags */)
}
func overlayWriteOut(ctx context.Context, o *overlayEntry) error {
@@ -89,7 +89,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
}
// Are we done?
- if overlayHasWhiteout(parent.upper, name) {
+ if overlayHasWhiteout(ctx, parent.upper, name) {
if upperInode == nil {
parent.copyMu.RUnlock()
if negativeUpperChild {
@@ -345,7 +345,7 @@ func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child *
}
}
if child.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(o.upper, child.name); err != nil {
+ if err := overlayCreateWhiteout(ctx, o.upper, child.name); err != nil {
return err
}
}
@@ -426,7 +426,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
return err
}
if renamed.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName); err != nil {
+ if err := overlayCreateWhiteout(ctx, oldParent.Inode.overlay.upper, oldName); err != nil {
return err
}
}
@@ -528,7 +528,7 @@ func overlayUnstableAttr(ctx context.Context, o *overlayEntry) (UnstableAttr, er
return attr, err
}
-func overlayGetxattr(o *overlayEntry, name string) (string, error) {
+func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uint64) (string, error) {
// Hot path. This is how the overlay checks for whiteout files.
// Avoid defers.
var (
@@ -544,31 +544,38 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) {
o.copyMu.RLock()
if o.upper != nil {
- s, err = o.upper.Getxattr(name)
+ s, err = o.upper.GetXattr(ctx, name, size)
} else {
- s, err = o.lower.Getxattr(name)
+ s, err = o.lower.GetXattr(ctx, name, size)
}
o.copyMu.RUnlock()
return s, err
}
-// TODO(b/146028302): Support setxattr for overlayfs.
-func overlaySetxattr(o *overlayEntry, name, value string) error {
- return syserror.EOPNOTSUPP
+func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
+ // Don't allow changes to overlay xattrs through a setxattr syscall.
+ if strings.HasPrefix(name, XattrOverlayPrefix) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.SetXattr(ctx, d, name, value, flags)
}
-func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
+func overlayListXattr(ctx context.Context, o *overlayEntry) (map[string]struct{}, error) {
o.copyMu.RLock()
defer o.copyMu.RUnlock()
var names map[string]struct{}
var err error
if o.upper != nil {
- names, err = o.upper.Listxattr()
+ names, err = o.upper.ListXattr(ctx)
} else {
- names, err = o.lower.Listxattr()
+ names, err = o.lower.ListXattr(ctx)
}
for name := range names {
- // Same as overlayGetxattr, we shouldn't forward along
+ // Same as overlayGetXattr, we shouldn't forward along
// overlay attributes.
if strings.HasPrefix(XattrOverlayPrefix, name) {
delete(names, name)
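
overlaySetxattr now rejects names under the overlay's own xattr prefix so user code cannot forge or clear whiteout markers, then copies the dirent up so the attribute lands on the writable upper layer. A small sketch of the name guard, assuming the package's usual prefix constant:

    package example

    import "strings"

    // xattrOverlayPrefix mirrors the fs package's XattrOverlayPrefix constant.
    const xattrOverlayPrefix = "trusted.overlay."

    // allowed reports whether a user-supplied name may be set through setxattr
    // on an overlay inode; overlay-internal names are reserved.
    func allowed(name string) bool {
        return !strings.HasPrefix(name, xattrOverlayPrefix)
    }
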
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
index 8935aad65..493d98c36 100644
--- a/pkg/sentry/fs/inode_overlay_test.go
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -382,8 +382,8 @@ type dir struct {
ReaddirCalled bool
}
-// Getxattr implements InodeOperations.Getxattr.
-func (d *dir) Getxattr(inode *fs.Inode, name string) (string, error) {
+// GetXattr implements InodeOperations.GetXattr.
+func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
for _, n := range d.negative {
if name == fs.XattrOverlayWhiteout(n) {
return "y", nil
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index ba3e0233d..cc7dd1c92 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -16,7 +16,6 @@ package fs
import (
"io"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
index 0aa0a5e9b..900cba3ca 100644
--- a/pkg/sentry/fs/inotify_watch.go
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -15,10 +15,10 @@
package fs
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watch represent a particular inotify watch created by inotify_add_watch.
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 8d62642e7..2c332a82a 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -44,6 +44,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 636484424..926538d90 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -52,9 +52,9 @@ package lock
import (
"fmt"
"math"
- "sync"
"syscall"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -78,6 +78,9 @@ const (
)
// LockEOF is the maximal possible end of a regional file lock.
+//
+// A BSD-style full file lock can be represented as a regional file lock from
+// offset 0 to LockEOF.
const LockEOF = math.MaxUint64
// Lock is a regional file lock. It consists of either a single writer
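
Since LockEOF is the largest representable offset, a BSD-style (flock) whole-file lock is simply the region [0, LockEOF]; POSIX byte-range locks use the same machinery with narrower ranges. A minimal sketch, assuming the package's generated LockRange type with Start and End fields:

    package example

    import "gvisor.dev/gvisor/pkg/sentry/fs/lock"

    // wholeFile returns the range covering an entire file, which is how a
    // BSD-style full file lock is expressed in this package.
    func wholeFile() lock.LockRange {
        return lock.LockRange{Start: 0, End: lock.LockEOF}
    }
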
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index ac0398bd9..db3dfd096 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -19,7 +19,6 @@ import (
"math"
"path"
"strings"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index 25573e986..4cad55327 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -17,13 +17,12 @@ package fs
import (
"fmt"
"strings"
- "sync"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -199,7 +198,7 @@ type overlayEntry struct {
upper *Inode
// dirCacheMu protects dirCache.
- dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"`
+ dirCacheMu sync.DowngradableRWMutex `state:"nosave"`
// dirCache is cache of DentAttrs from upper and lower Inodes.
dirCache *SortedDentryMap
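
The former syncutil package has been folded into pkg/sync, which re-exports the standard sync primitives plus gVisor-specific ones, so most files in this change only swap the import path. The overlay keeps a DowngradableRWMutex for its directory cache: a writer can populate the cache and then downgrade to a read lock without a window in which another writer could intervene. A hedged sketch of that usage (method names assumed from the old syncutil API):

    package example

    import "gvisor.dev/gvisor/pkg/sync"

    var dirCacheMu sync.DowngradableRWMutex

    func rebuildThenRead(rebuild func(), read func()) {
        dirCacheMu.Lock()          // exclusive while rebuilding the cache
        rebuild()
        dirCacheMu.DowngradeLock() // trade the write lock for a read lock atomically
        read()
        dirCacheMu.RUnlock()
    }
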
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 75cbb0622..cb37c6c6b 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -18,7 +18,6 @@ go_library(
"mounts.go",
"net.go",
"proc.go",
- "rpcinet_proc.go",
"stat.go",
"sys.go",
"sys_net.go",
@@ -46,11 +45,11 @@ go_library(
"//pkg/sentry/limits",
"//pkg/sentry/mm",
"//pkg/sentry/socket",
- "//pkg/sentry/socket/rpcinet",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/waiter",
diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go
index 05e31c55d..c4abe319d 100644
--- a/pkg/sentry/fs/proc/cgroup.go
+++ b/pkg/sentry/fs/proc/cgroup.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) *fs.Inode {
// From man 7 cgroups: "For each cgroup hierarchy of which the process
// is a member, there is one entry containing three colon-separated
@@ -39,3 +41,5 @@ func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers
return newStaticProcInode(ctx, msrc, []byte(data))
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go
index 3edf36780..df0c4e3a7 100644
--- a/pkg/sentry/fs/proc/cpuinfo.go
+++ b/pkg/sentry/fs/proc/cpuinfo.go
@@ -15,11 +15,15 @@
package proc
import (
+ "bytes"
+
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
k := kernel.KernelFromContext(ctx)
features := k.FeatureSet()
@@ -27,9 +31,11 @@ func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// Kernel is always initialized with a FeatureSet.
panic("cpuinfo read with nil FeatureSet")
}
- contents := make([]byte, 0, 1024)
+ var buf bytes.Buffer
for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
- contents = append(contents, []byte(features.CPUInfo(i))...)
+ features.WriteCPUInfoTo(i, &buf)
}
- return newStaticProcInode(ctx, msrc, contents)
+ return newStaticProcInode(ctx, msrc, buf.Bytes())
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
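
Rendering each CPU's block straight into a bytes.Buffer via the new WriteCPUInfoTo avoids allocating an intermediate string per core. The buffer pattern itself in a self-contained form (the processor line format is illustrative, not the full cpuinfo output):

    package example

    import (
        "bytes"
        "fmt"
    )

    // renderCPUInfo writes one block per application core into a single buffer
    // and returns the accumulated bytes.
    func renderCPUInfo(cores uint) []byte {
        var buf bytes.Buffer
        for i := uint(0); i < cores; i++ {
            fmt.Fprintf(&buf, "processor\t: %d\n\n", i)
        }
        return buf.Bytes()
    }
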
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 1d3a2d426..9aaeb780b 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -29,6 +29,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// execArgType enumerates the types of exec arguments that are exposed through
// proc.
type execArgType int
@@ -201,3 +203,5 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
}
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go)
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
index bee421d76..2fa3cfa7d 100644
--- a/pkg/sentry/fs/proc/fds.go
+++ b/pkg/sentry/fs/proc/fds.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// walkDescriptors finds the descriptor (file-flag pair) for the fd identified
// by p, and calls the toInodeOperations callback with that descriptor. This is a helper
// method for implementing fs.InodeOperations.Lookup.
@@ -277,3 +279,5 @@ func (fdid *fdInfoDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.
}
return fs.NewFile(ctx, dirent, flags, fops), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go
index e9250c51c..7b3b974ab 100644
--- a/pkg/sentry/fs/proc/filesystems.go
+++ b/pkg/sentry/fs/proc/filesystems.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// filesystemsData backs /proc/filesystems.
//
// +stateify savable
@@ -59,3 +61,5 @@ func (*filesystemsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle
// Return the SeqData and advance the generation counter.
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*filesystemsData)(nil)}}, 1
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go
index f14833805..761d24462 100644
--- a/pkg/sentry/fs/proc/fs.go
+++ b/pkg/sentry/fs/proc/fs.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
// filesystem is a procfs.
//
// +stateify savable
@@ -79,3 +81,5 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// never want them cached.
return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags), cgroups)
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
index 0c04f81fa..723f6b661 100644
--- a/pkg/sentry/fs/proc/inode.go
+++ b/pkg/sentry/fs/proc/inode.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
+// LINT.IfChange
+
// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
// method to return either the task or root as the owner, depending on the
// task's dumpability.
@@ -131,3 +133,5 @@ func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
}
return fs.NewInode(ctx, iops, msrc, sattr)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go
index 8602b7426..d7d2afcb7 100644
--- a/pkg/sentry/fs/proc/loadavg.go
+++ b/pkg/sentry/fs/proc/loadavg.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// loadavgData backs /proc/loadavg.
//
// +stateify savable
@@ -53,3 +55,5 @@ func (d *loadavgData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
index 495f3e3ba..313c6a32b 100644
--- a/pkg/sentry/fs/proc/meminfo.go
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
+// LINT.IfChange
+
// meminfoData backs /proc/meminfo.
//
// +stateify savable
@@ -83,3 +85,5 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
fmt.Fprintf(&buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*meminfoData)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index e33c4a460..5aedae799 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// forEachMountSource runs f for the process root mount and each mount that is a
// descendant of the root.
func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
@@ -195,3 +197,5 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountsFile)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 402919924..3f17e98ea 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -38,6 +38,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+// LINT.IfChange
+
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
@@ -831,3 +833,5 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
return data, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_net.go)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 56e92721e..29867dc3a 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -27,10 +27,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// proc is a root proc node.
//
// +stateify savable
@@ -85,15 +86,9 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
}
// Add more contents that need proc to be initialized.
+ p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
- // If we're using rpcinet we will let it manage /proc/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- p.AddChild(ctx, "net", newRPCInetProcNet(ctx, msrc))
- } else {
- p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
- }
-
return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
}
@@ -249,3 +244,5 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
}
return offset, nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/rpcinet_proc.go b/pkg/sentry/fs/proc/rpcinet_proc.go
deleted file mode 100644
index 01ac97530..000000000
--- a/pkg/sentry/fs/proc/rpcinet_proc.go
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// rpcInetInode implements fs.InodeOperations.
-type rpcInetInode struct {
- fsutil.SimpleFileInode
-
- // filepath is the full path of this rpcInetInode.
- filepath string
-
- k *kernel.Kernel
-}
-
-func newRPCInetInode(ctx context.Context, msrc *fs.MountSource, filepath string, mode linux.FileMode) *fs.Inode {
- f := &rpcInetInode{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(mode), linux.PROC_SUPER_MAGIC),
- filepath: filepath,
- k: kernel.KernelFromContext(ctx),
- }
- return newProcInode(ctx, f, msrc, fs.SpecialFile, nil)
-}
-
-// GetFile implements fs.InodeOperations.GetFile.
-func (i *rpcInetInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- flags.Pread = true
- flags.Pwrite = true
- fops := &rpcInetFile{
- inode: i,
- }
- return fs.NewFile(ctx, dirent, flags, fops), nil
-}
-
-// rpcInetFile implements fs.FileOperations as RPCs.
-type rpcInetFile struct {
- fsutil.FileGenericSeek `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopFsync `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
- inode *rpcInetInode
-}
-
-// Read implements fs.FileOperations.Read.
-//
-// This method can panic if an rpcInetInode was created without an rpcinet
-// stack.
-func (f *rpcInetFile) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- contents, se := s.RPCReadFile(f.inode.filepath)
- if se != nil || offset >= int64(len(contents)) {
- return 0, io.EOF
- }
-
- n, err := dst.CopyOut(ctx, contents[offset:])
- return int64(n), err
-}
-
-// Write implements fs.FileOperations.Write.
-//
-// This method can panic if an rpcInetInode was created without an rpcInet
-// stack.
-func (f *rpcInetFile) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- if src.NumBytes() == 0 {
- return 0, nil
- }
-
- b := make([]byte, src.NumBytes(), src.NumBytes())
- n, err := src.CopyIn(ctx, b)
- if err != nil {
- return int64(n), err
- }
-
- written, se := s.RPCWriteFile(f.inode.filepath, b)
- return int64(written), se.ToError()
-}
-
-// newRPCInetProcNet will build an inode for /proc/net.
-func newRPCInetProcNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "arp": newRPCInetInode(ctx, msrc, "/proc/net/arp", 0444),
- "dev": newRPCInetInode(ctx, msrc, "/proc/net/dev", 0444),
- "if_inet6": newRPCInetInode(ctx, msrc, "/proc/net/if_inet6", 0444),
- "ipv6_route": newRPCInetInode(ctx, msrc, "/proc/net/ipv6_route", 0444),
- "netlink": newRPCInetInode(ctx, msrc, "/proc/net/netlink", 0444),
- "netstat": newRPCInetInode(ctx, msrc, "/proc/net/netstat", 0444),
- "packet": newRPCInetInode(ctx, msrc, "/proc/net/packet", 0444),
- "protocols": newRPCInetInode(ctx, msrc, "/proc/net/protocols", 0444),
- "psched": newRPCInetInode(ctx, msrc, "/proc/net/psched", 0444),
- "ptype": newRPCInetInode(ctx, msrc, "/proc/net/ptype", 0444),
- "route": newRPCInetInode(ctx, msrc, "/proc/net/route", 0444),
- "tcp": newRPCInetInode(ctx, msrc, "/proc/net/tcp", 0444),
- "tcp6": newRPCInetInode(ctx, msrc, "/proc/net/tcp6", 0444),
- "udp": newRPCInetInode(ctx, msrc, "/proc/net/udp", 0444),
- "udp6": newRPCInetInode(ctx, msrc, "/proc/net/udp6", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetProcSysNet will build an inode for /proc/sys/net.
-func newRPCInetProcSysNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ipv4": newRPCInetSysNetIPv4Dir(ctx, msrc),
- "core": newRPCInetSysNetCore(ctx, msrc),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetCore builds the /proc/sys/net/core directory.
-func newRPCInetSysNetCore(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "default_qdisc": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/default_qdisc", 0444),
- "message_burst": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_burst", 0444),
- "message_cost": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_cost", 0444),
- "optmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/optmem_max", 0444),
- "rmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_default", 0444),
- "rmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_max", 0444),
- "somaxconn": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/somaxconn", 0444),
- "wmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_default", 0444),
- "wmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_max", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetIPv4Dir builds the /proc/sys/net/ipv4 directory.
-func newRPCInetSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ip_local_port_range": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_port_range", 0444),
- "ip_local_reserved_ports": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_reserved_ports", 0444),
- "ipfrag_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ipfrag_time", 0444),
- "ip_nonlocal_bind": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_nonlocal_bind", 0444),
- "ip_no_pmtu_disc": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_no_pmtu_disc", 0444),
- "tcp_allowed_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_allowed_congestion_control", 0444),
- "tcp_available_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_available_congestion_control", 0444),
- "tcp_base_mss": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_base_mss", 0444),
- "tcp_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_congestion_control", 0644),
- "tcp_dsack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_dsack", 0644),
- "tcp_early_retrans": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_early_retrans", 0644),
- "tcp_fack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fack", 0644),
- "tcp_fastopen": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen", 0644),
- "tcp_fastopen_key": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen_key", 0444),
- "tcp_fin_timeout": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fin_timeout", 0644),
- "tcp_invalid_ratelimit": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_invalid_ratelimit", 0444),
- "tcp_keepalive_intvl": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_intvl", 0644),
- "tcp_keepalive_probes": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_probes", 0644),
- "tcp_keepalive_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_time", 0644),
- "tcp_mem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mem", 0444),
- "tcp_mtu_probing": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mtu_probing", 0644),
- "tcp_no_metrics_save": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_no_metrics_save", 0444),
- "tcp_probe_interval": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_interval", 0444),
- "tcp_probe_threshold": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_threshold", 0444),
- "tcp_retries1": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries1", 0644),
- "tcp_retries2": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries2", 0644),
- "tcp_rfc1337": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rfc1337", 0444),
- "tcp_rmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rmem", 0444),
- "tcp_sack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_sack", 0644),
- "tcp_slow_start_after_idle": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_slow_start_after_idle", 0644),
- "tcp_synack_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_synack_retries", 0644),
- "tcp_syn_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_syn_retries", 0644),
- "tcp_timestamps": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_timestamps", 0644),
- "tcp_wmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_wmem", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index fe7067be1..38b246dff 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/kernel/time",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
index 5fe823000..f9af191d5 100644
--- a/pkg/sentry/fs/proc/seqfile/seqfile.go
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -17,7 +17,6 @@ package seqfile
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -26,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go
index b641effbb..bc5b2bc7b 100644
--- a/pkg/sentry/fs/proc/stat.go
+++ b/pkg/sentry/fs/proc/stat.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// statData backs /proc/stat.
//
// +stateify savable
@@ -140,3 +142,5 @@ func (s *statData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index cd37776c8..2bdcf5f70 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -26,11 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// mmapMinAddrData backs /proc/sys/vm/mmap_min_addr.
//
// +stateify savable
@@ -104,16 +105,10 @@ func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
func (p *proc) newSysDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
children := map[string]*fs.Inode{
"kernel": p.newKernelDir(ctx, msrc),
+ "net": p.newSysNetDir(ctx, msrc),
"vm": p.newVMDir(ctx, msrc),
}
- // If we're using rpcinet we will let it manage /proc/sys/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- children["net"] = newRPCInetProcSysNet(ctx, msrc)
- } else {
- children["net"] = p.newSysNetDir(ctx, msrc)
- }
-
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
@@ -160,3 +155,5 @@ func (hf *hostnameFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
}
var _ fs.FileOperations = (*hostnameFile)(nil)
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index bd93f83fa..b9e8ef35f 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -17,7 +17,6 @@ package proc
import (
"fmt"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -27,9 +26,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
type tcpMemDir int
const (
@@ -364,3 +366,5 @@ func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 9bf4b4527..7358d6ef9 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -37,6 +37,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
// users count is incremented, and must be decremented by the caller when it is
// no longer in use.
@@ -800,3 +802,5 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, buf[offset:])
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
index eea37d15c..3eacc9265 100644
--- a/pkg/sentry/fs/proc/uid_gid_map.go
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -30,6 +30,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// idMapInodeOperations implements fs.InodeOperations for
// /proc/[pid]/{uid,gid}_map.
//
@@ -177,3 +179,5 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
// count, even if fewer bytes were used.
return int64(srclen), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
index 4e903917a..adfe58adb 100644
--- a/pkg/sentry/fs/proc/uptime.go
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// uptime is a file containing the system uptime.
//
// +stateify savable
@@ -85,3 +87,5 @@ func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, s[offset:])
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go
index a6d2c3cd3..27fd5b1cb 100644
--- a/pkg/sentry/fs/proc/version.go
+++ b/pkg/sentry/fs/proc/version.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// versionData backs /proc/version.
//
// +stateify savable
@@ -76,3 +78,5 @@ func (v *versionData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 012cb3e44..3fb7b0633 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -21,6 +21,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index 78e082b8e..dcbb8eb2e 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -17,7 +17,6 @@ package ramfs
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go
index f10168125..64c6a6ae9 100644
--- a/pkg/sentry/fs/restore.go
+++ b/pkg/sentry/fs/restore.go
@@ -15,7 +15,7 @@
package fs
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RestoreEnvironment is the restore environment for file systems. It consists
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 59ce400c2..3400b940c 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index f86dfaa36..f1c87fe41 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -17,7 +17,6 @@ package tmpfs
import (
"fmt"
"io"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 69089c8a8..0f718e236 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -148,19 +148,19 @@ func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms
return d.ramfsDir.CreateFifo(ctx, dir, name, perms)
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (d *Dir) Getxattr(i *fs.Inode, name string) (string, error) {
- return d.ramfsDir.Getxattr(i, name)
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (d *Dir) GetXattr(ctx context.Context, i *fs.Inode, name string, size uint64) (string, error) {
+ return d.ramfsDir.GetXattr(ctx, i, name, size)
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (d *Dir) Setxattr(i *fs.Inode, name, value string) error {
- return d.ramfsDir.Setxattr(i, name, value)
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (d *Dir) SetXattr(ctx context.Context, i *fs.Inode, name, value string, flags uint32) error {
+ return d.ramfsDir.SetXattr(ctx, i, name, value, flags)
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (d *Dir) Listxattr(i *fs.Inode) (map[string]struct{}, error) {
- return d.ramfsDir.Listxattr(i)
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode) (map[string]struct{}, error) {
+ return d.ramfsDir.ListXattr(ctx, i)
}
// Lookup implements fs.InodeOperations.Lookup.
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 95ad98cb0..f6f60d0cf 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 2f639c823..88aa66b24 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -19,7 +19,6 @@ import (
"fmt"
"math"
"strconv"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 7cc0eb409..894964260 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -16,13 +16,13 @@ package tty
import (
"bytes"
- "sync"
"unicode/utf8"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index 231e4e6eb..8b5d4699a 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -15,13 +15,12 @@
package tty
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index bc90330bc..903874141 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -50,6 +50,7 @@ go_library(
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 91802dc1e..8944171c8 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -15,8 +15,6 @@
package ext
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/log"
@@ -25,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 616fc002a..9afb1a84c 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -17,13 +17,13 @@ package ext
import (
"errors"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index aec33e00a..d11153c90 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -16,7 +16,6 @@ package ext
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
index 39c03ee9d..809178250 100644
--- a/pkg/sentry/fsimpl/kernfs/BUILD
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -39,6 +39,7 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -56,6 +57,7 @@ go_test(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"@com_github_google_go-cmp//cmp:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
index 79759e0fc..a4600ad47 100644
--- a/pkg/sentry/fsimpl/kernfs/filesystem.go
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -40,7 +39,7 @@ func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingP
return nil, syserror.ENOTDIR
}
// Directory searchable?
- if err := d.inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
afterSymlink:
@@ -182,8 +181,8 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving
//
// Preconditions: Filesystem.mu must be locked for at least reading. parentInode
// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true.
-func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return "", err
}
pc := rp.Component()
@@ -206,7 +205,7 @@ func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInod
// checkDeleteLocked checks that the file represented by vfsd may be deleted.
//
// Preconditions: Filesystem.mu must be locked for at least reading.
-func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
+func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
parentVFSD := vfsd.Parent()
if parentVFSD == nil {
return syserror.EBUSY
@@ -214,36 +213,12 @@ func checkDeleteLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
if parentVFSD.IsDisowned() {
return syserror.ENOENT
}
- if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ if err := parentVFSD.Impl().(*Dentry).inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
return err
}
return nil
}
-// checkRenameLocked checks that a rename operation may be performed on the
-// target dentry across the given set of parent directories. The target dentry
-// may be nil.
-//
-// Precondition: isDir(dstInode) == true.
-func checkRenameLocked(creds *auth.Credentials, src, dstDir *vfs.Dentry, dstInode Inode) error {
- srcDir := src.Parent()
- if srcDir == nil {
- return syserror.EBUSY
- }
- if srcDir.IsDisowned() {
- return syserror.ENOENT
- }
- if dstDir.IsDisowned() {
- return syserror.ENOENT
- }
- // Check for creation permissions on dst dir.
- if err := dstInode.CheckPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
- return err
- }
-
- return nil
-}
-
// Release implements vfs.FilesystemImpl.Release.
func (fs *Filesystem) Release() {
}
@@ -269,7 +244,7 @@ func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
if !d.isDir() {
return nil, syserror.ENOTDIR
}
- if err := inode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
}
@@ -302,7 +277,7 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -339,7 +314,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -367,7 +342,7 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -401,7 +376,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if err != nil {
return nil, err
}
- if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return inode.Open(rp, vfsd, opts.Flags)
@@ -420,7 +395,7 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if mustCreate {
return nil, syserror.EEXIST
}
- if err := inode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return inode.Open(rp, vfsd, opts.Flags)
@@ -432,7 +407,7 @@ afterTrailingSymlink:
return nil, err
}
// Check for search permission in the parent directory.
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
return nil, err
}
// Reject attempts to open directories with O_CREAT.
@@ -450,7 +425,7 @@ afterTrailingSymlink:
}
if childVFSD == nil {
// Already checked for searchability above; now check for writability.
- if err := parentInode.CheckPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
return nil, err
}
if err := rp.Mount().CheckBeginWrite(); err != nil {
@@ -485,7 +460,7 @@ afterTrailingSymlink:
goto afterTrailingSymlink
}
}
- if err := childInode.CheckPermissions(rp.Credentials(), ats); err != nil {
+ if err := childInode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
return nil, err
}
return childInode.Open(rp, childVFSD, opts.Flags)
@@ -545,13 +520,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
srcVFSD := &src.vfsd
// Can we remove the src dentry?
- if err := checkDeleteLocked(rp, srcVFSD); err != nil {
+ if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil {
return err
}
// Can we create the dst dentry?
var dstVFSD *vfs.Dentry
- pc, err := checkCreateLocked(rp, dstDirVFSD, dstDirInode)
+ pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode)
switch err {
case nil:
// Ok, continue with rename as replacement.
@@ -607,7 +582,7 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
return err
}
if !vfsd.Impl().(*Dentry).isDir() {
@@ -683,7 +658,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
if err != nil {
return err
}
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
if err != nil {
return err
}
@@ -712,7 +687,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return err
}
defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(rp, vfsd); err != nil {
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
return err
}
if vfsd.Impl().(*Dentry).isDir() {
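
A note on the signature change above: CheckPermissions (and the check* helpers) now take a context.Context so implementations can consult the calling task. A minimal sketch of an inode using that, with ownerCheckedInode as a hypothetical name (proc's taskOwnedInode and commInode later in this change follow the same pattern):

// Illustrative only, not part of this change.
type ownerCheckedInode struct {
	kernfs.InodeAttrs

	owner *kernel.Task
}

// CheckPermissions grants access when the caller is in the owner's thread
// group, and otherwise falls back to the generic mode-based check.
func (i *ownerCheckedInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
	if t := kernel.TaskFromContext(ctx); t != nil && t.ThreadGroup() == i.owner.ThreadGroup() {
		return nil
	}
	return i.InodeAttrs.CheckPermissions(ctx, creds, ats)
}
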
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
index 752e0f659..1700fffd9 100644
--- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -16,7 +16,6 @@ package kernfs
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -262,7 +262,7 @@ func (a *InodeAttrs) SetStat(_ *vfs.Filesystem, opts vfs.SetStatOptions) error {
}
// CheckPermissions implements Inode.CheckPermissions.
-func (a *InodeAttrs) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := a.Mode()
return vfs.GenericCheckPermissions(
creds,
@@ -510,3 +510,47 @@ type InodeSymlink struct {
func (InodeSymlink) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
return nil, syserror.ELOOP
}
+
+// StaticDirectory is a standard implementation of a directory with static
+// contents.
+//
+// +stateify savable
+type StaticDirectory struct {
+ InodeNotSymlink
+ InodeDirectoryNoNewChildren
+ InodeAttrs
+ InodeNoDynamicLookup
+ OrderedChildren
+}
+
+var _ Inode = (*StaticDirectory)(nil)
+
+// NewStaticDir creates a new static directory and returns its dentry.
+func NewStaticDir(creds *auth.Credentials, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry {
+ inode := &StaticDirectory{}
+ inode.Init(creds, ino, perm)
+
+ dentry := &Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, children)
+ inode.IncLinks(links)
+
+ return dentry
+}
+
+// Init initializes StaticDirectory.
+func (s *StaticDirectory) Init(creds *auth.Credentials, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ s.InodeAttrs.Init(creds, ino, linux.ModeDirectory|perm)
+}
+
+// Open implements kernfs.Inode.
+func (s *StaticDirectory) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &s.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
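
A minimal usage sketch for the new StaticDirectory helper, assuming creds and inoGen values like the ones proc uses later in this change:

// Illustrative only: a small static subtree built with the new helpers.
children := map[string]*kernfs.Dentry{
	"mounts": kernfs.NewStaticSymlink(creds, inoGen.NextIno(), "self/mounts"),
}
dir := kernfs.NewStaticDir(creds, inoGen.NextIno(), 0555, children)
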
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index d69b299ae..85bcdcc57 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -53,7 +53,6 @@ package kernfs
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -61,6 +60,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FilesystemType implements vfs.FilesystemType.
@@ -320,7 +320,7 @@ type inodeMetadata interface {
// CheckPermissions checks that creds may access this inode for the
// requested access type, per the rules of
// fs/namei.c:generic_permission().
- CheckPermissions(creds *auth.Credentials, atx vfs.AccessTypes) error
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, atx vfs.AccessTypes) error
// Mode returns the (struct stat)::st_mode value for this inode. This is
// separated from Stat for performance.
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index 4b6b95f5f..5c9d580e1 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -19,7 +19,6 @@ import (
"fmt"
"io"
"runtime"
- "sync"
"testing"
"github.com/google/go-cmp/cmp"
@@ -31,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
index 068063f4e..f19f12854 100644
--- a/pkg/sentry/fsimpl/kernfs/symlink.go
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -20,7 +20,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
-type staticSymlink struct {
+// StaticSymlink provides an Inode implementation for symlinks that point to
+// an immutable target.
+type StaticSymlink struct {
InodeAttrs
InodeNoopRefCount
InodeSymlink
@@ -28,18 +30,25 @@ type staticSymlink struct {
target string
}
-var _ Inode = (*staticSymlink)(nil)
+var _ Inode = (*StaticSymlink)(nil)
// NewStaticSymlink creates a new symlink file pointing to 'target'.
-func NewStaticSymlink(creds *auth.Credentials, ino uint64, perm linux.FileMode, target string) *Dentry {
- inode := &staticSymlink{target: target}
- inode.Init(creds, ino, linux.ModeSymlink|perm)
+func NewStaticSymlink(creds *auth.Credentials, ino uint64, target string) *Dentry {
+ inode := &StaticSymlink{}
+ inode.Init(creds, ino, target)
d := &Dentry{}
d.Init(inode)
return d
}
-func (s *staticSymlink) Readlink(_ context.Context) (string, error) {
+// Init initializes the instance.
+func (s *StaticSymlink) Init(creds *auth.Credentials, ino uint64, target string) {
+ s.target = target
+ s.InodeAttrs.Init(creds, ino, linux.ModeSymlink|0777)
+}
+
+// Readlink implements Inode.
+func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
return s.target, nil
}
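
Init is exported above so other inode types can embed StaticSymlink; proc's namespace symlinks below do exactly that. A sketch of the pattern, with wrappedSymlink as a hypothetical example type:

// Illustrative only, not part of this change.
type wrappedSymlink struct {
	kernfs.StaticSymlink
}

func newWrappedSymlink(creds *auth.Credentials, ino uint64, target string) *kernfs.Dentry {
	inode := &wrappedSymlink{}
	inode.Init(creds, ino, target) // StaticSymlink.Init sets mode to ModeSymlink|0777.
	d := &kernfs.Dentry{}
	d.Init(inode)
	return d
}
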
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 1f44b3217..e92564b5d 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -7,17 +7,13 @@ go_library(
name = "proc",
srcs = [
"filesystem.go",
- "loadavg.go",
- "meminfo.go",
- "mounts.go",
- "net.go",
- "stat.go",
- "sys.go",
+ "subtasks.go",
"task.go",
"task_files.go",
"tasks.go",
"tasks_files.go",
- "version.go",
+ "tasks_net.go",
+ "tasks_sys.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
deps = [
@@ -30,8 +26,10 @@ go_library(
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
+ "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
@@ -47,7 +45,7 @@ go_test(
size = "small",
srcs = [
"boot_test.go",
- "net_test.go",
+ "tasks_sys_test.go",
"tasks_test.go",
],
embed = [":proc"],
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index d09182c77..e9cb7895f 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -67,3 +67,14 @@ func newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode d
d.Init(inode)
return d
}
+
+type staticFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFile)(nil)
+
+func newStaticFile(data string) *staticFile {
+ return &staticFile{StaticData: vfs.StaticData{Data: data}}
+}
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
deleted file mode 100644
index 5351d86e8..000000000
--- a/pkg/sentry/fsimpl/proc/loadavg.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
-)
-
-// loadavgData backs /proc/loadavg.
-//
-// +stateify savable
-type loadavgData struct {
- kernfs.DynamicBytesFile
-}
-
-var _ dynamicInode = (*loadavgData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/62345059): Include real data in fields.
- // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
- // Column 4-5: currently running processes and the total number of processes.
- // Column 6: the last process ID used.
- fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
deleted file mode 100644
index cbdd4f3fc..000000000
--- a/pkg/sentry/fsimpl/proc/meminfo.go
+++ /dev/null
@@ -1,79 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
-)
-
-// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
-//
-// +stateify savable
-type meminfoData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ dynamicInode = (*meminfoData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- mf := d.k.MemoryFile()
- mf.UpdateUsage()
- snapshot, totalUsage := usage.MemoryAccounting.Copy()
- totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
- anon := snapshot.Anonymous + snapshot.Tmpfs
- file := snapshot.PageCache + snapshot.Mapped
- // We don't actually have active/inactive LRUs, so just make up numbers.
- activeFile := (file / 2) &^ (usermem.PageSize - 1)
- inactiveFile := file - activeFile
-
- fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
- memFree := (totalSize - totalUsage) / 1024
- // We use MemFree as MemAvailable because we don't swap.
- // TODO(rahat): When reclaim is implemented the value of MemAvailable
- // should change.
- fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
- fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
- fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
- fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
- // Emulate a system with no swap, which disables inactivation of anon pages.
- fmt.Fprintf(buf, "SwapCache: 0 kB\n")
- fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
- fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
- fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
- fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
- fmt.Fprintf(buf, "SwapFree: 0 kB\n")
- fmt.Fprintf(buf, "Dirty: 0 kB\n")
- fmt.Fprintf(buf, "Writeback: 0 kB\n")
- fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
- fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go
deleted file mode 100644
index 50894a534..000000000
--- a/pkg/sentry/fsimpl/proc/stat.go
+++ /dev/null
@@ -1,129 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
-)
-
-// cpuStats contains the breakdown of CPU time for /proc/stat.
-type cpuStats struct {
- // user is time spent in userspace tasks with non-positive niceness.
- user uint64
-
- // nice is time spent in userspace tasks with positive niceness.
- nice uint64
-
- // system is time spent in non-interrupt kernel context.
- system uint64
-
- // idle is time spent idle.
- idle uint64
-
- // ioWait is time spent waiting for IO.
- ioWait uint64
-
- // irq is time spent in interrupt context.
- irq uint64
-
- // softirq is time spent in software interrupt context.
- softirq uint64
-
- // steal is involuntary wait time.
- steal uint64
-
- // guest is time spent in guests with non-positive niceness.
- guest uint64
-
- // guestNice is time spent in guests with positive niceness.
- guestNice uint64
-}
-
-// String implements fmt.Stringer.
-func (c cpuStats) String() string {
- return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
-}
-
-// statData implements vfs.DynamicBytesSource for /proc/stat.
-//
-// +stateify savable
-type statData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ dynamicInode = (*statData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/37226836): We currently export only zero CPU stats. We could
- // at least provide some aggregate stats.
- var cpu cpuStats
- fmt.Fprintf(buf, "cpu %s\n", cpu)
-
- for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
- fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
- }
-
- // The total number of interrupts is dependent on the CPUs and PCI
- // devices on the system. See arch_probe_nr_irqs.
- //
- // Since we don't report real interrupt stats, just choose an arbitrary
- // value from a representative VM.
- const numInterrupts = 256
-
- // The Kernel doesn't handle real interrupts, so report all zeroes.
- // TODO(b/37226836): We could count page faults as #PF.
- fmt.Fprintf(buf, "intr 0") // total
- for i := 0; i < numInterrupts; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
-
- // Total number of context switches.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "ctxt 0\n")
-
- // CLOCK_REALTIME timestamp from boot, in seconds.
- fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
-
- // Total number of clones.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "processes 0\n")
-
- // Number of runnable tasks.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_running 0\n")
-
- // Number of tasks waiting on IO.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_blocked 0\n")
-
- // Number of each softirq handled.
- fmt.Fprintf(buf, "softirq 0") // total
- for i := 0; i < linux.NumSoftIRQ; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
new file mode 100644
index 000000000..8892c5a11
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -0,0 +1,126 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// subtasksInode represents the inode for the /proc/[pid]/task/ directory.
+//
+// +stateify savable
+type subtasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ inoGen InoGenerator
+}
+
+var _ kernfs.Inode = (*subtasksInode)(nil)
+
+func newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, inoGen InoGenerator) *kernfs.Dentry {
+ subInode := &subtasksInode{
+ task: task,
+ pidns: pidns,
+ inoGen: inoGen,
+ }
+ // Note: credentials are overridden by taskOwnedInode.
+ subInode.InodeAttrs.Init(task.Credentials(), inoGen.NextIno(), linux.ModeDirectory|0555)
+ subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ inode := &taskOwnedInode{Inode: subInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ return dentry
+}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Valid(ctx context.Context) bool {
+ return true
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ tid, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ subTask := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if subTask == nil {
+ return nil, syserror.ENOENT
+ }
+ if subTask.ThreadGroup() != i.task.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ subTaskDentry := newTaskInode(i.inoGen, subTask, i.pidns, false)
+ return subTaskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
+ if len(tasks) == 0 {
+ return offset, syserror.ENOENT
+ }
+
+ tids := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ tids = append(tids, int(tid))
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.inoGen.NextIno(),
+ NextOff: offset + 1,
+ }
+ if !cb.Handle(dirent) {
+ return offset, nil
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// Open implements kernfs.Inode.
+func (i *subtasksInode) Open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
+ fd := &kernfs.GenericDirectoryFD{}
+ fd.Init(rp.Mount(), vfsd, &i.OrderedChildren, flags)
+ return fd.VFSFileDescription(), nil
+}
+
+// Stat implements kernfs.Inode.
+func (i *subtasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
+ stat := i.InodeAttrs.Stat(vsfs)
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ return stat
+}
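
IterDirents above relies on the callback's boolean return to stop iteration early. A minimal sketch of a callback that just collects names, assuming the Handle(vfs.Dirent) bool contract used in this file:

// Illustrative only, not part of this change.
type direntCollector struct {
	names []string
}

// Handle records the entry name and returns true to keep iterating;
// returning false would make IterDirents stop at the current offset.
func (c *direntCollector) Handle(d vfs.Dirent) bool {
	c.names = append(c.names, d.Name)
	return true
}
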
diff --git a/pkg/sentry/fsimpl/proc/sys.go b/pkg/sentry/fsimpl/proc/sys.go
deleted file mode 100644
index b88256e12..000000000
--- a/pkg/sentry/fsimpl/proc/sys.go
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// mmapMinAddrData implements vfs.DynamicBytesSource for
-// /proc/sys/vm/mmap_min_addr.
-//
-// +stateify savable
-type mmapMinAddrData struct {
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*mmapMinAddrData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
- return nil
-}
-
-// +stateify savable
-type overcommitMemory struct{}
-
-var _ vfs.DynamicBytesSource = (*overcommitMemory)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *overcommitMemory) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "0\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index 11a64c777..621c17cfe 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -15,6 +15,8 @@
package proc
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
@@ -42,27 +44,31 @@ var _ kernfs.Inode = (*taskInode)(nil)
func newTaskInode(inoGen InoGenerator, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool) *kernfs.Dentry {
contents := map[string]*kernfs.Dentry{
- //"auxv": newAuxvec(t, msrc),
- //"cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- //"comm": newComm(t, msrc),
- //"environ": newExecArgInode(t, msrc, environExecArg),
+ "auxv": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": newComm(task, inoGen.NextIno(), 0444),
+ "environ": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
//"exe": newExe(t, msrc),
//"fd": newFdDir(t, msrc),
//"fdinfo": newFdInfoDir(t, msrc),
- //"gid_map": newGIDMap(t, msrc),
- "io": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, newIO(task, isThreadGroup)),
- "maps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &mapsData{task: task}),
+ "gid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": newTaskOwnedFile(task, inoGen.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &mapsData{task: task}),
//"mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
//"mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
- //"ns": newNamespaceDir(t, msrc),
- "smaps": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &smapsData{task: task}),
- "stat": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &taskStatData{t: task, pidns: pidns, tgstats: isThreadGroup}),
- "statm": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statmData{t: task}),
- "status": newTaskOwnedFile(task, inoGen.NextIno(), defaultPermission, &statusData{t: task, pidns: pidns}),
- //"uid_map": newUIDMap(t, msrc),
+ "ns": newTaskOwnedDir(task, inoGen.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "net": newNamespaceSymlink(task, inoGen.NextIno(), "net"),
+ "pid": newNamespaceSymlink(task, inoGen.NextIno(), "pid"),
+ "user": newNamespaceSymlink(task, inoGen.NextIno(), "user"),
+ }),
+ "smaps": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &smapsData{task: task}),
+ "stat": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statmData{task: task}),
+ "status": newTaskOwnedFile(task, inoGen.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": newTaskOwnedFile(task, inoGen.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
- //contents["task"] = p.newSubtasks(t, msrc)
+ contents["task"] = newSubtasks(task, pidns, inoGen)
}
//if len(p.cgroupControllers) > 0 {
// contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers)
@@ -127,6 +133,23 @@ func newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode
return d
}
+func newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &kernfs.StaticDirectory{}
+
+ // Note: credentials are overridden by taskOwnedInode.
+ dir.Init(task.Credentials(), ino, perm)
+
+ inode := &taskOwnedInode{Inode: dir, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := dir.OrderedChildren.Populate(d, children)
+ dir.IncLinks(links)
+
+ return d
+}
+
// Stat implements kernfs.Inode.
func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
stat := i.Inode.Stat(fs)
@@ -137,7 +160,7 @@ func (i *taskOwnedInode) Stat(fs *vfs.Filesystem) linux.Statx {
}
// CheckPermissions implements kernfs.Inode.
-func (i *taskOwnedInode) CheckPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
mode := i.Mode()
uid, gid := i.getOwner(mode)
return vfs.GenericCheckPermissions(
@@ -188,3 +211,19 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
}
return &ioData{ioUsage: t}
}
+
+func newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode number
+ // for the namespace instance, so for example user:[123456]. We currently fake
+ // the inode number by sticking the symlink inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
+
+ inode := &kernfs.StaticSymlink{}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), ino, target)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
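
newTaskOwnedFile, newTaskOwnedDir, and newNamespaceSymlink all wrap the underlying inode in taskOwnedInode so reported ownership tracks the task. The common step, pulled out as a hypothetical helper for illustration:

// Illustrative only, not part of this change.
func newTaskOwned(task *kernel.Task, inode kernfs.Inode) *kernfs.Dentry {
	d := &kernfs.Dentry{}
	d.Init(&taskOwnedInode{Inode: inode, owner: task})
	return d
}
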
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 93f0e1aa8..7bc352ae9 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -17,15 +17,20 @@ package proc
import (
"bytes"
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// mm gets the kernel task's MemoryManager. No additional reference is taken on
@@ -41,6 +46,256 @@ func getMM(task *kernel.Task) *mm.MemoryManager {
return tmm
}
+// getMMIncRef returns t's MemoryManager. If getMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
+
+type bufferWriter struct {
+ buf *bytes.Buffer
+}
+
+// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+// the number of bytes written. It may return a partial write without an
+// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
+// return a full write with an error (i.e. (srcs.NumBytes(), err) where err
+// != nil).
+func (w *bufferWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ written := srcs.NumBytes()
+ for !srcs.IsEmpty() {
+ w.buf.Write(srcs.Head().ToSlice())
+ srcs = srcs.Tail()
+ }
+ return written, nil
+}
+
+// auxvData implements vfs.DynamicBytesSource for /proc/[pid]/auxv.
+//
+// +stateify savable
+type auxvData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*auxvData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Space for buffer with AT_NULL (0) terminator at the end.
+ auxv := m.Auxv()
+ buf.Grow((len(auxv) + 1) * 16)
+ for _, e := range auxv {
+ var tmp [8]byte
+ usermem.ByteOrder.PutUint64(tmp[:], e.Key)
+ buf.Write(tmp[:])
+
+ usermem.ByteOrder.PutUint64(tmp[:], uint64(e.Value))
+ buf.Write(tmp[:])
+ }
+ return nil
+}
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineDataArg execArgType = iota
+ environDataArg
+)
+
+// cmdlineData implements vfs.DynamicBytesSource for /proc/[pid]/cmdline.
+//
+// +stateify savable
+type cmdlineData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+}
+
+var _ dynamicInode = (*cmdlineData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var ar usermem.AddrRange
+ switch d.arg {
+ case cmdlineDataArg:
+ ar = usermem.AddrRange{
+ Start: m.ArgvStart(),
+ End: m.ArgvEnd(),
+ }
+ case environDataArg:
+ ar = usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", d.arg))
+ }
+ if ar.Start == 0 || ar.End == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return io.EOF
+ }
+
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ writer := &bufferWriter{buf: buf}
+ if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
+ // Nothing to copy or something went wrong.
+ return err
+ }
+
+ // On Linux, if the NULL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
+ if d.arg == cmdlineDataArg && buf.Bytes()[buf.Len()-1] != 0 {
+ if end := bytes.IndexByte(buf.Bytes(), 0); end != -1 {
+ // If we found a NULL character somewhere else in argv, truncate the
+ // return up to the NULL terminator (including it).
+ buf.Truncate(end)
+ return nil
+ }
+
+	// There is no NULL terminator in the string; spill over into envp.
+ arEnvv := usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+	// We'll return one page total between argv and envp because of the
+ // above page restrictions.
+ if buf.Len() >= usermem.PageSize {
+ // Returned at least one page already, nothing else to add.
+ return nil
+ }
+ remaining := usermem.PageSize - buf.Len()
+ if int(arEnvv.Length()) > remaining {
+ end, ok := arEnvv.Start.AddLength(uint64(remaining))
+ if !ok {
+ return syserror.EFAULT
+ }
+ arEnvv.End = end
+ }
+ if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
+ return err
+ }
+
+ // Linux will return envp up to and including the first NULL character,
+ // so find it.
+ if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 {
+ buf.Truncate(end)
+ }
+ }
+
+ return nil
+}
+
+// +stateify savable
+type commInode struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+func newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+ inode := &commInode{task: task}
+ inode.DynamicBytesFile.Init(task.Credentials(), ino, &commData{task: task}, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // This file can always be read or written by members of the same thread
+ // group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing and
+ // this file is world-readable anyways.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == i.task.ThreadGroup() && !ats.MayExec() {
+ return nil
+ }
+
+ return i.DynamicBytesFile.CheckPermissions(ctx, creds, ats)
+}
+
+// commData implements vfs.DynamicBytesSource for /proc/[pid]/comm.
+//
+// +stateify savable
+type commData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*commData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.task.Name())
+ buf.WriteString("\n")
+ return nil
+}
+
+// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}.
+//
+// +stateify savable
+type idMapData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ gids bool
+}
+
+var _ dynamicInode = (*idMapData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var entries []auth.IDMapEntry
+ if d.gids {
+ entries = d.task.UserNamespace().GIDMap()
+ } else {
+ entries = d.task.UserNamespace().UIDMap()
+ }
+ for _, e := range entries {
+ fmt.Fprintf(buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ return nil
+}
+
// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
//
// +stateify savable
@@ -83,7 +338,7 @@ func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
type taskStatData struct {
kernfs.DynamicBytesFile
- t *kernel.Task
+ task *kernel.Task
// If tgstats is true, accumulate fault stats (not implemented) and CPU
// time across all tasks in t's thread group.
@@ -98,40 +353,40 @@ var _ dynamicInode = (*taskStatData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
- fmt.Fprintf(buf, "(%s) ", s.t.Name())
- fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.task))
+ fmt.Fprintf(buf, "(%s) ", s.task.Name())
+ fmt.Fprintf(buf, "%c ", s.task.StateStatus()[0])
ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
+ if parent := s.task.Parent(); parent != nil {
ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
}
fmt.Fprintf(buf, "%d ", ppid)
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.task.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.task.ThreadGroup().Session()))
fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
fmt.Fprintf(buf, "0 " /* flags */)
fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
var cputime usage.CPUStats
if s.tgstats {
- cputime = s.t.ThreadGroup().CPUStats()
+ cputime = s.task.ThreadGroup().CPUStats()
} else {
- cputime = s.t.CPUStats()
+ cputime = s.task.CPUStats()
}
fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- cputime = s.t.ThreadGroup().JoinedChildCPUStats()
+ cputime = s.task.ThreadGroup().JoinedChildCPUStats()
fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
+ fmt.Fprintf(buf, "%d %d ", s.task.Priority(), s.task.Niceness())
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Count())
// itrealvalue. Since kernel 2.6.17, this field is no longer
// maintained, and is hard coded as 0.
fmt.Fprintf(buf, "0 ")
// Start time is relative to boot time, expressed in clock ticks.
- fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
+ s.task.WithMuLocked(func(t *kernel.Task) {
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
rss = mm.ResidentSetSize()
@@ -140,14 +395,14 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
// rsslim.
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur)
fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
terminationSignal := linux.Signal(0)
- if s.t == s.t.ThreadGroup().Leader() {
- terminationSignal = s.t.ThreadGroup().TerminationSignal()
+ if s.task == s.task.ThreadGroup().Leader() {
+ terminationSignal = s.task.ThreadGroup().TerminationSignal()
}
fmt.Fprintf(buf, "%d ", terminationSignal)
fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
@@ -164,7 +419,7 @@ func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
type statmData struct {
kernfs.DynamicBytesFile
- t *kernel.Task
+ task *kernel.Task
}
var _ dynamicInode = (*statmData)(nil)
@@ -172,7 +427,7 @@ var _ dynamicInode = (*statmData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
+ s.task.WithMuLocked(func(t *kernel.Task) {
if mm := t.MemoryManager(); mm != nil {
vss = mm.VirtualMemorySize()
rss = mm.ResidentSetSize()
@@ -189,7 +444,7 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
type statusData struct {
kernfs.DynamicBytesFile
- t *kernel.Task
+ task *kernel.Task
pidns *kernel.PIDNamespace
}
@@ -197,23 +452,23 @@ var _ dynamicInode = (*statusData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
- fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
- fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
- fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
+ fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
+ if parent := s.task.Parent(); parent != nil {
ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
}
fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
tpid := kernel.ThreadID(0)
- if tracer := s.t.Tracer(); tracer != nil {
+ if tracer := s.task.Tracer(); tracer != nil {
tpid = s.pidns.IDOfTask(tracer)
}
fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
var fds int
var vss, rss, data uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
+ s.task.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
fds = fdTable.Size()
}
@@ -227,13 +482,13 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
- fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
- creds := s.t.Credentials()
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
+ creds := s.task.Credentials()
fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
- fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.task.SeccompMode())
// We unconditionally report a single NUMA node. See
// pkg/sentry/syscalls/linux/sys_mempolicy.go.
fmt.Fprintf(buf, "Mems_allowed:\t1\n")
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index d8f92d52f..a97b1753a 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -15,6 +15,7 @@
package proc
import (
+ "bytes"
"sort"
"strconv"
@@ -28,9 +29,8 @@ import (
)
const (
- defaultPermission = 0444
- selfName = "self"
- threadSelfName = "thread-self"
+ selfName = "self"
+ threadSelfName = "thread-self"
)
// InoGenerator generates unique inode numbers for a given filesystem.
@@ -61,15 +61,15 @@ var _ kernfs.Inode = (*tasksInode)(nil)
func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNamespace) (*tasksInode, *kernfs.Dentry) {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]*kernfs.Dentry{
- //"cpuinfo": newCPUInfo(ctx, msrc),
- //"filesystems": seqfile.NewSeqFileInode(ctx, &filesystemsData{}, msrc),
- "loadavg": newDentry(root, inoGen.NextIno(), defaultPermission, &loadavgData{}),
- "meminfo": newDentry(root, inoGen.NextIno(), defaultPermission, &meminfoData{k: k}),
- "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), defaultPermission, "self/mounts"),
- "stat": newDentry(root, inoGen.NextIno(), defaultPermission, &statData{k: k}),
- //"uptime": newUptime(ctx, msrc),
- //"version": newVersionData(root, inoGen.NextIno(), k),
- "version": newDentry(root, inoGen.NextIno(), defaultPermission, &versionData{k: k}),
+ "cpuinfo": newDentry(root, inoGen.NextIno(), 0444, newStaticFile(cpuInfoData(k))),
+ //"filesystems": newDentry(root, inoGen.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": newDentry(root, inoGen.NextIno(), 0444, &loadavgData{}),
+ "sys": newSysDir(root, inoGen),
+ "meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
}
inode := &tasksInode{
@@ -216,3 +216,20 @@ func (i *tasksInode) Stat(vsfs *vfs.Filesystem) linux.Statx {
return stat
}
+
+func cpuInfoData(k *kernel.Kernel) string {
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ var buf bytes.Buffer
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ features.WriteCPUInfoTo(i, &buf)
+ }
+ return buf.String()
+}
+
+func shmData(v uint64) dynamicInode {
+ return newStaticFile(strconv.FormatUint(v, 10))
+}
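Both the new "cpuinfo" entry and shmData above rely on a newStaticFile helper defined elsewhere in the proc package. As a minimal, purely illustrative sketch (not the actual implementation), such a static-content dynamicInode could look like this, assuming the package's existing kernfs, context, and bytes imports:

// staticFile is a hypothetical dynamicInode whose content never changes.
type staticFile struct {
	kernfs.DynamicBytesFile

	// content is returned verbatim on every read.
	content string
}

// Generate implements vfs.DynamicBytesSource.Generate.
func (s *staticFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
	buf.WriteString(s.content)
	return nil
}

// newStaticFile would then simply wrap the string.
func newStaticFile(data string) *staticFile {
	return &staticFile{content: data}
}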
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index 91f30a798..ad3760e39 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -15,6 +15,7 @@
package proc
import (
+ "bytes"
"fmt"
"strconv"
@@ -23,6 +24,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -90,3 +94,244 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
}
return fmt.Sprintf("%d/task/%d", tgid, tid), nil
}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// statData implements vfs.DynamicBytesSource for /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*statData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(buf, "cpu %s\n", cpu)
+
+ for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+ return nil
+}
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/62345059): Include real data in fields.
+ // Columns 1-3: CPU and IO utilization over the last 1, 5, and 10 minute periods.
+ // Columns 4-5: the number of currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+ return nil
+}
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ mf := d.k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := (totalSize - totalUsage) / 1024
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
+ fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return nil
+}
+
+// uptimeData implements vfs.DynamicBytesSource for /proc/uptime.
+//
+// +stateify savable
+type uptimeData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*uptimeData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ now := time.NowFromContext(ctx)
+
+ // Pretend that we've spent zero time sleeping (second number).
+ fmt.Fprintf(buf, "%.2f 0.00\n", now.Sub(k.Timekeeper().BootTime()).Seconds())
+ return nil
+}
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ kernfs.DynamicBytesFile
+
+ // k is the owning Kernel.
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ init := v.k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+ // - COMPILE_USER is the user that built the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+ return nil
+}
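Because every counter above is still hard-coded to zero, the generated /proc/stat content is almost entirely zeroes. On a hypothetical two-core configuration the first few lines would read roughly as follows (the btime value is illustrative; it is the boot time in seconds):

cpu 0 0 0 0 0 0 0 0 0 0
cpu0 0 0 0 0 0 0 0 0 0 0
cpu1 0 0 0 0 0 0 0 0 0 0
intr 0 0 0 0 ...   (256 zero per-interrupt counters follow the total)
ctxt 0
btime 1578000000
processes 0
procs_running 0
procs_blocked 0
softirq 0 0 0 ...  (one zero counter per softirq class follows the total)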
diff --git a/pkg/sentry/fsimpl/proc/net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index fd46eebf8..06dc43c26 100644
--- a/pkg/sentry/fsimpl/proc/net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -46,8 +46,7 @@ func (n *ifinet6) contents() []string {
for id, naddrs := range n.s.InterfaceAddrs() {
nic, ok := nics[id]
if !ok {
- // NIC was added after NICNames was called. We'll just
- // ignore it.
+ // NIC was added after NICNames was called. We'll just ignore it.
continue
}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
new file mode 100644
index 000000000..aabf2bf0c
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// newSysDir returns the dentry corresponding to the /proc/sys directory.
+func newSysDir(root *auth.Credentials, inoGen InoGenerator) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "kernel": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "hostname": newDentry(root, inoGen.NextIno(), 0444, &hostnameData{}),
+ "shmall": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMALL)),
+ "shmmax": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMAX)),
+ "shmmni": newDentry(root, inoGen.NextIno(), 0444, shmData(linux.SHMMNI)),
+ }),
+ "vm": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "mmap_min_addr": newDentry(root, inoGen.NextIno(), 0444, &mmapMinAddrData{}),
+ "overcommit_memory": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0\n")),
+ }),
+ "net": newSysNetDir(root, inoGen),
+ })
+}
+
+// newSysNetDir returns the dentry corresponding to the /proc/sys/net directory.
+func newSysNetDir(root *auth.Credentials, inoGen InoGenerator) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "net": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ // TODO(gvisor.dev/issue/1195): Add tcp_sack, which allows
+ // write(2).
+ // "tcp_sack": newTCPSackInode(ctx, msrc, s),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack; most of them are configuration related. We use the value
+ // closest to the actual netstack behavior or an empty file. All of
+ // these files have mode 0444 (read-only for all users).
+ "ip_local_port_range": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "ipfrag_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+
+ // tcp_allowed_congestion_control tells the user what they are able to
+ // do as an unprivileged process, so we leave it empty.
+ "tcp_allowed_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "tcp_available_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_congestion_control": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("reno")),
+
+ // Many of the following stub files are features netstack doesn't
+ // support. The unsupported features return "0" to indicate they are
+ // disabled.
+ "tcp_base_mss": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1280")),
+ "tcp_dsack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_early_retrans": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fack": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen_key": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_time": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_probe_interval": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_probe_threshold": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "tcp_retries1": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")),
+ "tcp_retries2": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("15")),
+ "tcp_rfc1337": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ "tcp_synack_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")),
+ "tcp_syn_retries": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("3")),
+ "tcp_timestamps": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("1")),
+ }),
+ "core": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "default_qdisc": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("pfifo_fast")),
+ "message_burst": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("10")),
+ "message_cost": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("5")),
+ "optmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("0")),
+ "rmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "rmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "somaxconn": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("128")),
+ "wmem_default": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ "wmem_max": newDentry(root, inoGen.NextIno(), 0444, newStaticFile("212992")),
+ }),
+ }),
+ })
+}
+
+// mmapMinAddrData implements vfs.DynamicBytesSource for
+// /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ kernfs.DynamicBytesFile
+
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*mmapMinAddrData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
+ return nil
+}
+
+// hostnameData implements vfs.DynamicBytesSource for /proc/sys/kernel/hostname.
+//
+// +stateify savable
+type hostnameData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*hostnameData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*hostnameData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ buf.WriteString(utsns.HostName())
+ buf.WriteString("\n")
+ return nil
+}
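All of these entries follow the same Generate pattern. As a rough, hypothetical illustration of how hostnameData produces its content (ctx is assumed to be a context that carries a UTS namespace, as in the package's tests):

var d hostnameData
var buf bytes.Buffer
if err := d.Generate(ctx, &buf); err != nil {
	return err // Generation cannot fail here, but handle it for completeness.
}
// buf now holds the UTS hostname followed by a newline, e.g. "runsc\n".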
diff --git a/pkg/sentry/fsimpl/proc/net_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index 20a77a8ca..20a77a8ca 100644
--- a/pkg/sentry/fsimpl/proc/net_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
index ca8c87ec2..6b58c16b9 100644
--- a/pkg/sentry/fsimpl/proc/tasks_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -69,12 +69,15 @@ func checkDots(dirs []vfs.Dirent) ([]vfs.Dirent, error) {
func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
wants := map[string]vfs.Dirent{
+ "cpuinfo": {Type: linux.DT_REG},
"loadavg": {Type: linux.DT_REG},
"meminfo": {Type: linux.DT_REG},
"mounts": {Type: linux.DT_LNK},
"self": selfLink,
"stat": {Type: linux.DT_REG},
+ "sys": {Type: linux.DT_DIR},
"thread-self": threadSelfLink,
+ "uptime": {Type: linux.DT_REG},
"version": {Type: linux.DT_REG},
}
return checkFiles(gots, wants)
@@ -82,12 +85,20 @@ func checkTasksStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
func checkTaskStaticFiles(gots []vfs.Dirent) ([]vfs.Dirent, error) {
wants := map[string]vfs.Dirent{
- "io": {Type: linux.DT_REG},
- "maps": {Type: linux.DT_REG},
- "smaps": {Type: linux.DT_REG},
- "stat": {Type: linux.DT_REG},
- "statm": {Type: linux.DT_REG},
- "status": {Type: linux.DT_REG},
+ "auxv": {Type: linux.DT_REG},
+ "cmdline": {Type: linux.DT_REG},
+ "comm": {Type: linux.DT_REG},
+ "environ": {Type: linux.DT_REG},
+ "gid_map": {Type: linux.DT_REG},
+ "io": {Type: linux.DT_REG},
+ "maps": {Type: linux.DT_REG},
+ "ns": {Type: linux.DT_DIR},
+ "smaps": {Type: linux.DT_REG},
+ "stat": {Type: linux.DT_REG},
+ "statm": {Type: linux.DT_REG},
+ "status": {Type: linux.DT_REG},
+ "task": {Type: linux.DT_DIR},
+ "uid_map": {Type: linux.DT_REG},
}
return checkFiles(gots, wants)
}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
deleted file mode 100644
index 367f2396b..000000000
--- a/pkg/sentry/fsimpl/proc/version.go
+++ /dev/null
@@ -1,70 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
-)
-
-// versionData implements vfs.DynamicBytesSource for /proc/version.
-//
-// +stateify savable
-type versionData struct {
- kernfs.DynamicBytesFile
-
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ dynamicInode = (*versionData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- init := v.k.GlobalInit()
- if init == nil {
- // Attempted to read before the init Task is created. This can
- // only occur during startup, which should never need to read
- // this file.
- panic("Attempted to read version before initial Task is available")
- }
-
- // /proc/version takes the form:
- //
- // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
- // (COMPILER_VERSION) VERSION"
- //
- // where:
- // - SYSNAME, RELEASE, and VERSION are the same as returned by
- // sys_utsname
- // - COMPILE_USER is the user that build the kernel
- // - COMPILE_HOST is the hostname of the machine on which the kernel
- // was built
- // - COMPILER_VERSION is the version reported by the building compiler
- //
- // Since we don't really want to expose build information to
- // applications, those fields are omitted.
- //
- // FIXME(mpratt): Using Version from the init task SyscallTable
- // disregards the different version a task may have (e.g., in a uts
- // namespace).
- ver := init.Leader().SyscallTable().Version
- fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
index a5b285987..7601c7c04 100644
--- a/pkg/sentry/fsimpl/tmpfs/BUILD
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -40,6 +40,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
@@ -47,6 +48,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -76,6 +78,7 @@ go_test(
srcs = [
"pipe_test.go",
"regular_file_test.go",
+ "stat_test.go",
],
embed = [":tmpfs"],
deps = [
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 26979729e..4cd7e9aea 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -56,7 +56,8 @@ afterSymlink:
}
next := nextVFSD.Impl().(*dentry)
if symlink, ok := next.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
+ // TODO(gvisor.dev/issues/1197): Symlink traversals should update
+ // access time.
if err := rp.HandleSymlink(symlink.target); err != nil {
return nil, err
}
@@ -501,7 +502,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa
oldParent.inode.decLinksLocked()
newParent.inode.incLinksLocked()
}
- // TODO: update timestamps and parent directory sizes
+ // TODO(gvisor.dev/issues/1197): Update timestamps and parent directory
+ // sizes.
vfsObj.CommitRenameReplaceDentry(renamedVFSD, &newParent.vfsd, newName, replacedVFSD)
return nil
}
@@ -555,15 +557,11 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
fs.mu.RLock()
defer fs.mu.RUnlock()
- _, err := resolveLocked(rp)
+ d, err := resolveLocked(rp)
if err != nil {
return err
}
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
+ return d.inode.setStat(opts.Stat)
}
// StatAt implements vfs.FilesystemImpl.StatAt.
@@ -587,7 +585,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
if err != nil {
return linux.Statfs{}, err
}
- // TODO: actually implement statfs
+ // TODO(gvisor.dev/issues/1197): Actually implement statfs.
return linux.Statfs{}, syserror.ENOSYS
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index f51e247a7..5fa70cc6d 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -17,7 +17,6 @@ package tmpfs
import (
"io"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -30,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -63,6 +63,41 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
return &file.inode
}
+// truncate grows or shrinks the file to the given size. It returns true if the
+// file size was updated.
+func (rf *regularFile) truncate(size uint64) (bool, error) {
+ rf.mu.Lock()
+ defer rf.mu.Unlock()
+
+ if size == rf.size {
+ // Nothing to do.
+ return false, nil
+ }
+
+ if size > rf.size {
+ // Growing the file.
+ if rf.seals&linux.F_SEAL_GROW != 0 {
+ // Seal does not allow growth.
+ return false, syserror.EPERM
+ }
+ rf.size = size
+ return true, nil
+ }
+
+ // Shrinking the file.
+ if rf.seals&linux.F_SEAL_SHRINK != 0 {
+ // Seal does not allow shrink.
+ return false, syserror.EPERM
+ }
+
+ // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
+ // mappings.
+
+ rf.data.Truncate(size, rf.memFile)
+ rf.size = size
+ return true, nil
+}
+
type regularFileFD struct {
fileDescription
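The (bool, error) return of truncate lets the caller decide whether timestamps need to change. A simplified sketch of the intended caller pattern, mirroring the setStat logic added to tmpfs.go later in this change (names are reused from that context):

updated, err := rf.truncate(stat.Size)
if err != nil {
	return err
}
if updated {
	// The size actually changed, so both mtime and ctime get bumped.
	needsMtimeBump = true
	needsCtimeBump = true
}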
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
index 3731c5b6f..034a29fdb 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"io"
+ "sync/atomic"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -29,10 +30,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
-// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
-// the returned err is not nil, then cleanup should be called when the FD is no
-// longer needed.
-func newFileFD(ctx context.Context, filename string) (*vfs.FileDescription, func(), error) {
+// nextFileID is used to generate unique file names.
+var nextFileID int64
+
+// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
+// is nil, then cleanup should be called when the root is no longer needed.
+func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
creds := auth.CredentialsFromContext(ctx)
vfsObj := vfs.New()
@@ -41,36 +44,124 @@ func newFileFD(ctx context.Context, filename string) (*vfs.FileDescription, func
})
mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
- return nil, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
}
root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef()
+ mntns.DecRef(vfsObj)
+ }, nil
+}
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is nil, then cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
// Create the file that will be write/read.
fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
- Root: root,
- Start: root,
- Path: fspath.Parse(filename),
- FollowFinalSymlink: true,
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
- Mode: 0644,
+ Mode: linux.ModeRegular | mode,
})
if err != nil {
- root.DecRef()
- mntns.DecRef(vfsObj)
+ cleanup()
return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
}
- return fd, func() {
- root.DecRef()
- mntns.DecRef(vfsObj)
- }, nil
+ return fd, cleanup, nil
+}
+
+// newDirFD is like newFileFD, but for directories.
+func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the dir.
+ if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.MkdirOptions{
+ Mode: linux.ModeDirectory | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
+ }
+
+ // Open the dir and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newPipeFD is like newFileFD, but for pipes.
+func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ pipename := fmt.Sprintf("tmpfs-test-pipe-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the pipe.
+ if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(pipename),
+ }, &vfs.MknodOptions{
+ Mode: linux.ModeNamedPipe | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create pipe %q: %v", pipename, err)
+ }
+
+ // Open the pipe and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(pipename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open pipe %q: %v", pipename, err)
+ }
+
+ return fd, cleanup, nil
}
// Test that we can write some data to a file and read it back.
func TestSimpleWriteRead(t *testing.T) {
ctx := contexttest.Context(t)
- fd, cleanup, err := newFileFD(ctx, "simpleReadWrite")
+ fd, cleanup, err := newFileFD(ctx, 0644)
if err != nil {
t.Fatal(err)
}
@@ -116,7 +207,7 @@ func TestSimpleWriteRead(t *testing.T) {
func TestPWrite(t *testing.T) {
ctx := contexttest.Context(t)
- fd, cleanup, err := newFileFD(ctx, "PRead")
+ fd, cleanup, err := newFileFD(ctx, 0644)
if err != nil {
t.Fatal(err)
}
@@ -171,7 +262,7 @@ func TestPWrite(t *testing.T) {
func TestPRead(t *testing.T) {
ctx := contexttest.Context(t)
- fd, cleanup, err := newFileFD(ctx, "PRead")
+ fd, cleanup, err := newFileFD(ctx, 0644)
if err != nil {
t.Fatal(err)
}
@@ -222,3 +313,124 @@ func TestPRead(t *testing.T) {
}
}
}
+
+func TestTruncate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill the file with some data.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+
+ // Size should be same as written.
+ sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE}
+ stat, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := int64(stat.Size), written; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+
+ // Truncate down.
+ newSize := uint64(10)
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateDown.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should only read newSize worth of data.
+ buf := make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime)
+ }
+ if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate up.
+ newSize = 100
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateUp.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should read newSize worth of data.
+ buf = make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Bytes should be zero after offset 10, since we previously truncated to 10.
+ for i := uint64(10); i < newSize; i++ {
+ if buf[i] != 0 {
+ t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i])
+ break
+ }
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime)
+ }
+ if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate to the current size.
+ newSize = statAfterTruncateUp.Size
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ // Mtime and Ctime should not be bumped, since operation is a noop.
+ if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime)
+ }
+ if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime)
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
new file mode 100644
index 000000000..ebe035dee
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -0,0 +1,232 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestStatAfterCreate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ got, err := fd.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Atime, Ctime, Mtime should all be current time (non-zero).
+ atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec()
+ if atime != ctime || ctime != mtime {
+ t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime)
+ }
+ if atime == 0 {
+ t.Errorf("got atime=%d, want non-zero", atime)
+ }
+
+ // Btime should be 0, as it is not set by tmpfs.
+ if btime := got.Btime.ToNsec(); btime != 0 {
+ t.Errorf("got btime %d, want 0", got.Btime.ToNsec())
+ }
+
+ // Size should be 0.
+ if got.Size != 0 {
+ t.Errorf("got size %d, want 0", got.Size)
+ }
+
+ // Nlink should be 1 for files, 2 for dirs.
+ wantNlink := uint32(1)
+ if typ == "dir" {
+ wantNlink = 2
+ }
+ if got.Nlink != wantNlink {
+ t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink)
+ }
+
+ // UID and GID are set from context creds.
+ creds := auth.CredentialsFromContext(ctx)
+ if got.UID != uint32(creds.EffectiveKUID) {
+ t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID))
+ }
+ if got.GID != uint32(creds.EffectiveKGID) {
+ t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID))
+ }
+
+ // Mode.
+ wantMode := uint16(mode)
+ switch typ {
+ case "file":
+ wantMode |= linux.S_IFREG
+ case "dir":
+ wantMode |= linux.S_IFDIR
+ case "pipe":
+ wantMode |= linux.S_IFIFO
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+
+ if got.Mode != wantMode {
+ t.Errorf("got mode %x, want %x", got.Mode, wantMode)
+ }
+
+ // Ino.
+ if got.Ino == 0 {
+ t.Errorf("got ino %d, want not 0", got.Ino)
+ }
+ })
+ }
+}
+
+func TestSetStatAtime(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v")
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v")
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+}
+
+func TestSetStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ // TODO(gvisor.dev/issues/1197): Also test symlinks and sockets.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v")
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v")
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index 7be6faa5b..1d4889c89 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -26,14 +26,15 @@ package tmpfs
import (
"fmt"
"math"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -47,6 +48,9 @@ type filesystem struct {
// memFile is used to allocate pages to for regular files.
memFile *pgalloc.MemoryFile
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex
@@ -59,8 +63,10 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
if memFileProvider == nil {
panic("MemoryFileProviderFromContext returned nil")
}
+ clock := time.RealtimeClockFromContext(ctx)
fs := filesystem{
memFile: memFileProvider.MemoryFile(),
+ clock: clock,
}
fs.vfsfs.Init(vfsObj, &fs)
root := fs.newDentry(fs.newDirectory(creds, 01777))
@@ -116,6 +122,9 @@ func (d *dentry) DecRef() {
// inode represents a filesystem object.
type inode struct {
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
// refs is a reference count. refs is accessed using atomic memory
// operations.
//
@@ -126,26 +135,37 @@ type inode struct {
// filesystem.RmdirAt() drops the reference.
refs int64
- // Inode metadata; protected by mu and accessed using atomic memory
- // operations unless otherwise specified.
- mu sync.RWMutex
+ // Inode metadata. Writing multiple fields atomically requires holding
+ // mu; otherwise atomic operations can be used.
+ mu sync.Mutex
mode uint32 // excluding file type bits, which are based on impl
nlink uint32 // protected by filesystem.mu instead of inode.mu
uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
gid uint32 // auth.KGID, but ...
ino uint64 // immutable
+ // Linux's tmpfs has no concept of btime.
+ atime int64 // nanoseconds
+ ctime int64 // nanoseconds
+ mtime int64 // nanoseconds
+
impl interface{} // immutable
}
const maxLinks = math.MaxUint32
func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
+ i.clock = fs.clock
i.refs = 1
i.mode = uint32(mode)
i.uid = uint32(creds.EffectiveKUID)
i.gid = uint32(creds.EffectiveKGID)
i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
+ // Tmpfs creation sets atime, ctime, and mtime to the current time.
+ now := i.clock.Now().Nanoseconds()
+ i.atime = now
+ i.ctime = now
+ i.mtime = now
// i.nlink initialized by caller
i.impl = impl
}
@@ -213,15 +233,24 @@ func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, i
// Go won't inline this function, and returning linux.Statx (which is quite
// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
// output parameter.
+//
+// Note that Linux does not guarantee to return consistent data (in the case of
+// a concurrent modification), so we do not require holding inode.mu.
func (i *inode) statTo(stat *linux.Statx) {
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_ATIME |
+ linux.STATX_BTIME | linux.STATX_CTIME | linux.STATX_MTIME
stat.Blksize = 1 // usermem.PageSize in tmpfs
stat.Nlink = atomic.LoadUint32(&i.nlink)
stat.UID = atomic.LoadUint32(&i.uid)
stat.GID = atomic.LoadUint32(&i.gid)
stat.Mode = uint16(atomic.LoadUint32(&i.mode))
stat.Ino = i.ino
- // TODO: device number
+ // Linux's tmpfs has no concept of btime, so it is left as the zero value.
+ stat.Atime = linux.NsecToStatxTimestamp(i.atime)
+ stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
+ stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
+ // TODO(gvisor.dev/issues/1197): Device number.
switch impl := i.impl.(type) {
case *regularFile:
stat.Mode |= linux.S_IFREG
@@ -245,6 +274,75 @@ func (i *inode) statTo(stat *linux.Statx) {
}
}
+func (i *inode) setStat(stat linux.Statx) error {
+ if stat.Mask == 0 {
+ return nil
+ }
+ i.mu.Lock()
+ var (
+ needsMtimeBump bool
+ needsCtimeBump bool
+ )
+ mask := stat.Mask
+ if mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&i.mode, uint32(stat.Mode))
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&i.uid, stat.UID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&i.gid, stat.GID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ updated, err := impl.truncate(stat.Size)
+ if err != nil {
+ return err
+ }
+ if updated {
+ needsMtimeBump = true
+ needsCtimeBump = true
+ }
+ case *directory:
+ return syserror.EISDIR
+ case *symlink:
+ return syserror.EINVAL
+ case *namedPipe:
+ // Nothing.
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+ }
+ if mask&linux.STATX_ATIME != 0 {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ needsCtimeBump = true
+ // Ignore the mtime bump, since we just set it ourselves.
+ needsMtimeBump = false
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ // Ignore the ctime bump, since we just set it ourselves.
+ needsCtimeBump = false
+ }
+ now := i.clock.Now().Nanoseconds()
+ if needsMtimeBump {
+ atomic.StoreInt64(&i.mtime, now)
+ }
+ if needsCtimeBump {
+ atomic.StoreInt64(&i.ctime, now)
+ }
+ i.mu.Unlock()
+ return nil
+}
+
// allocatedBlocksForSize returns the number of 512B blocks needed to
// accommodate the given size in bytes, as appropriate for struct
// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
@@ -291,9 +389,5 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu
// SetStat implements vfs.FileDescriptionImpl.SetStat.
func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
+ return fd.inode().setStat(opts.Stat)
}
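As a quick illustration of the timestamp rules that inode.setStat implements (i is assumed to be an *inode from this package, and the values are arbitrary):

// chmod-style change: the mode is stored and only ctime is bumped.
_ = i.setStat(linux.Statx{Mask: linux.STATX_MODE, Mode: 0600})

// Truncating a regular file to a different size bumps both mtime and ctime.
_ = i.setStat(linux.Statx{Mask: linux.STATX_SIZE, Size: 4096})

// Explicitly setting mtime stores the caller's value as-is and bumps only ctime.
_ = i.setStat(linux.Statx{Mask: linux.STATX_MTIME, Mtime: linux.NsecToStatxTimestamp(100)})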
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 2706927ff..ac85ba0c8 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -35,7 +35,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
@@ -209,7 +209,7 @@ go_library(
"//pkg/sentry/usermem",
"//pkg/state",
"//pkg/state/statefile",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
@@ -241,6 +241,7 @@ go_test(
"//pkg/sentry/time",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 244655b5c..920fe4329 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -15,11 +15,11 @@
package kernel
import (
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
)
// +stateify savable
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 04c244447..1aa72fa47 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "Credentials",
},
@@ -64,6 +64,7 @@ go_library(
"//pkg/bits",
"//pkg/log",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go
index af28ccc65..9dd52c860 100644
--- a/pkg/sentry/kernel/auth/user_namespace.go
+++ b/pkg/sentry/kernel/auth/user_namespace.go
@@ -16,8 +16,8 @@ package auth
import (
"math"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index 3361e8b7d..c47f6b6fc 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 9c0a4e1b4..430311cc0 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -18,7 +18,6 @@ package epoll
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/refs"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index e65b961e8..c831fbab2 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 12f0d429b..687690679 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -18,7 +18,6 @@ package eventfd
import (
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 49d81b712..6b36bc63e 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -12,6 +12,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 6b0bb0324..d32c3e90a 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -16,12 +16,11 @@
package fasync
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 11f613a11..cd1501f85 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -18,7 +18,6 @@ import (
"bytes"
"fmt"
"math"
- "sync"
"sync/atomic"
"syscall"
@@ -28,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FDFlags define flags for an individual descriptor.
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2bcb6216a..eccb7d1e7 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -16,7 +16,6 @@ package kernel
import (
"runtime"
- "sync"
"testing"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/filetest"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index ded27d668..2448c1d99 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -16,10 +16,10 @@ package kernel
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FSContext contains filesystem context.
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 75ec31761..50db443ce 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "bucket",
},
@@ -42,6 +42,7 @@ go_library(
"//pkg/sentry/context",
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -51,5 +52,8 @@ go_test(
size = "small",
srcs = ["futex_test.go"],
embed = [":futex"],
- deps = ["//pkg/sentry/usermem"],
+ deps = [
+ "//pkg/sentry/usermem",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 278cc8143..d1931c8f4 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -18,11 +18,10 @@
package futex
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
index 65e5d1428..c23126ca5 100644
--- a/pkg/sentry/kernel/futex/futex_test.go
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -17,13 +17,13 @@ package futex
import (
"math"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"testing"
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// testData implements the Target interface, and allows us to
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8653d2f63..c85e97fef 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -36,7 +36,6 @@ import (
"fmt"
"io"
"path/filepath"
- "sync"
"sync/atomic"
"time"
@@ -67,6 +66,7 @@ import (
uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index d7a7d1169..7f36252a9 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/usage",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go
index b0d98e7f0..200565bb8 100644
--- a/pkg/sentry/kernel/memevent/memory_events.go
+++ b/pkg/sentry/kernel/memevent/memory_events.go
@@ -17,7 +17,6 @@
package memevent
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/eventchannel"
@@ -26,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
)
var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.")
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 9d34f6d4d..5eeaeff66 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 95bee2d37..1c0f34269 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -16,9 +16,9 @@ package pipe
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// buffer encapsulates a queueable byte buffer.
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 4a19ab7ce..716f589af 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -15,12 +15,11 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 1a1b38f83..e4fd7d420 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -17,12 +17,12 @@ package pipe
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index ef9641e6a..8394eb78b 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -17,7 +17,6 @@ package pipe
import (
"io"
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 6416e0dd8..bf7461cbb 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -15,13 +15,12 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index f4c00cd86..13a961594 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -31,6 +31,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index de9617e9d..18299814e 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -17,7 +17,6 @@ package semaphore
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index cd48945e6..7321b22ed 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -24,6 +24,7 @@ go_library(
"//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 19034a21e..8ddef7eb8 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -35,7 +35,6 @@ package shm
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
@@ -49,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
index a16f3d57f..768fda220 100644
--- a/pkg/sentry/kernel/signal_handlers.go
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -15,10 +15,9 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
)
// SignalHandlers holds information about signal actions.
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 9f7e19b4d..89e4d84b1 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 4b08d7d72..28be4a939 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -16,8 +16,6 @@
package signalfd
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -26,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 2fdee0282..d2d01add4 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -16,13 +16,13 @@ package kernel
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxSyscallNum is the highest supported syscall number.
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index 8227ecf1d..4607cde2f 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -17,7 +17,8 @@ package kernel
import (
"fmt"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// syslog represents a sentry-global kernel log.
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index d25a7903b..978d66da8 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -17,7 +17,6 @@ package kernel
import (
gocontext "context"
"runtime/trace"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -37,7 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -85,7 +84,7 @@ type Task struct {
//
// gosched is protected by goschedSeq. gosched is owned by the task
// goroutine.
- goschedSeq syncutil.SeqCount `state:"nosave"`
+ goschedSeq sync.SeqCount `state:"nosave"`
gosched TaskGoroutineSchedInfo
// yieldCount is the number of times the task goroutine has called
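
Task.goschedSeq keeps its seqcount protection under the new spelling sync.SeqCount. As a self-contained sketch of the seqlock pattern this enables: the owning goroutine wraps writes in a write section, and readers retry until a read section completes without a concurrent write. The BeginWrite/EndWrite and BeginRead/ReadOk names are assumptions about the pkg/sync API, and the task type below is illustrative, not the kernel's real accessors.

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sync"
)

// schedInfo stands in for TaskGoroutineSchedInfo: a multi-word value that
// must be observed as a consistent whole.
type schedInfo struct {
	state     int32
	timestamp uint64
}

type task struct {
	goschedSeq sync.SeqCount
	gosched    schedInfo
}

// setSched is called only by the owning goroutine.
func (t *task) setSched(info schedInfo) {
	t.goschedSeq.BeginWrite()
	t.gosched = info
	t.goschedSeq.EndWrite()
}

// readSched may be called from any goroutine; it retries on torn reads.
func (t *task) readSched() schedInfo {
	for {
		epoch := t.goschedSeq.BeginRead()
		info := t.gosched
		if t.goschedSeq.ReadOk(epoch) {
			return info
		}
	}
}

func main() {
	t := &task{}
	t.setSched(schedInfo{state: 1, timestamp: 42})
	fmt.Println(t.readSched())
}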
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index c0197a563..768e958d2 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -15,7 +15,6 @@
package kernel
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 8267929a6..bf2dabb6e 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -16,9 +16,9 @@ package kernel
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 31847e1df..4e4de0512 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -13,6 +13,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index 107394183..706de83ef 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -19,10 +19,10 @@ package time
import (
"fmt"
"math"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 76417342a..dc99301de 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -16,7 +16,6 @@ package kernel
import (
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/log"
@@ -24,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Timekeeper manages all of the kernel clocks.
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 048de26dc..464d2306a 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -14,7 +14,7 @@
package kernel
-import "sync"
+import "gvisor.dev/gvisor/pkg/sync"
// TTY defines the relationship between a thread group and its controlling
// terminal.
diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go
index 0a563e715..8ccf04bd1 100644
--- a/pkg/sentry/kernel/uts_namespace.go
+++ b/pkg/sentry/kernel/uts_namespace.go
@@ -15,9 +15,8 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
)
// UTSNamespace represents a UTS namespace, a holder of two system identifiers:
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 156e67bf8..9fa841e8b 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -15,6 +15,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/sentry/context",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go
index b6c22656b..31b9e9ff6 100644
--- a/pkg/sentry/limits/limits.go
+++ b/pkg/sentry/limits/limits.go
@@ -16,8 +16,9 @@
package limits
import (
- "sync"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// LimitType defines a type of resource limit.
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 839931f67..83e248431 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -118,7 +118,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/buffer",
],
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 1b746d030..4b48866ad 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -15,8 +15,6 @@
package mm
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -25,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 58a5c186d..fa86ebced 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -35,8 +35,6 @@
package mm
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -44,7 +42,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// MemoryManager implements a virtual address space.
@@ -82,7 +80,7 @@ type MemoryManager struct {
users int32
// mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
- mappingMu syncutil.DowngradableRWMutex `state:"nosave"`
+ mappingMu sync.DowngradableRWMutex `state:"nosave"`
// vmas stores virtual memory areas. Since vmas are stored by value,
// clients should usually use vmaIterator.ValuePtr() instead of
@@ -125,7 +123,7 @@ type MemoryManager struct {
// activeMu is loosely analogous to Linux's struct
// mm_struct::page_table_lock.
- activeMu syncutil.DowngradableRWMutex `state:"nosave"`
+ activeMu sync.DowngradableRWMutex `state:"nosave"`
// pmas stores platform mapping areas used to implement vmas. Since pmas
// are stored by value, clients should usually use pmaIterator.ValuePtr()
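
mappingMu and activeMu become sync.DowngradableRWMutex, which behaves like an RWMutex plus an operation that converts a held write lock into a read lock without releasing it. A minimal sketch of why that matters, assuming the type exposes Lock, DowngradeLock, and RUnlock (the populate/read helper below is hypothetical):

package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/sync"
)

var (
	mu   sync.DowngradableRWMutex
	vmas = map[string]int{}
)

// populateThenRead mutates under the exclusive lock, then downgrades so
// other readers can proceed while no writer can slip in before our read.
func populateThenRead() int {
	mu.Lock()
	vmas["stack"] = 1
	mu.DowngradeLock()
	defer mu.RUnlock()
	return vmas["stack"]
}

func main() {
	fmt.Println(populateThenRead()) // 1
}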
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index f404107af..a9a2642c5 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -73,6 +73,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index f7f7298c4..c99e023d9 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -25,7 +25,6 @@ import (
"fmt"
"math"
"os"
- "sync"
"sync/atomic"
"syscall"
"time"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index b6d008dbe..85e882df9 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -10,6 +10,7 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go
index a4651f500..57be41647 100644
--- a/pkg/sentry/platform/interrupt/interrupt.go
+++ b/pkg/sentry/platform/interrupt/interrupt.go
@@ -17,7 +17,8 @@ package interrupt
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Receiver receives interrupt notifications from a Forwarder.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index f3afd98da..6a358d1d4 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -55,6 +55,7 @@ go_library(
"//pkg/sentry/platform/safecopy",
"//pkg/sentry/time",
"//pkg/sentry/usermem",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index ea8b9632e..a25f3c449 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -15,13 +15,13 @@
package kvm
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// dirtySet tracks vCPUs for invalidation.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index e5fac0d6a..2f02c03cf 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -17,8 +17,6 @@
package kvm
import (
- "unsafe"
-
"gvisor.dev/gvisor/pkg/sentry/arch"
)
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index f2c2c059e..a7850faed 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -18,13 +18,13 @@ package kvm
import (
"fmt"
"os"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// KVM represents a lightweight VM context.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 7d02ebf19..e6d912168 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,7 +17,6 @@ package kvm
import (
"fmt"
"runtime"
- "sync"
"sync/atomic"
"syscall"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// machine contains state associated with the VM as a whole.
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index b99fe425e..873e39dc7 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -90,7 +90,9 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
// Clear from all PCIDs.
for _, c := range m.vCPUs {
- c.PCIDs.Drop(pt)
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
}
}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 7ae47f291..3b1f20219 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -97,7 +97,9 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
// Clear from all PCIDs.
for _, c := range m.vCPUs {
- c.PCIDs.Drop(pt)
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
}
}
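
Both dropPageTables variants now skip vCPUs whose PCID allocator was never set up; calling Drop through a nil *PCIDs would otherwise dereference nil (assuming Drop touches allocator state, e.g. takes its lock). A stand-alone sketch of the guarded loop, with hypothetical stand-in types:

package main

import "fmt"

type pageTables struct{ id int }

// pcids stands in for pagetables.PCIDs; a nil pointer means PCIDs are not
// in use for that vCPU.
type pcids struct{ cache map[*pageTables]uint16 }

func (p *pcids) drop(pt *pageTables) { delete(p.cache, pt) }

type vCPU struct{ pcids *pcids }

// dropPageTables mirrors the guarded loop: vCPUs without an allocator are
// skipped instead of panicking on a nil dereference.
func dropPageTables(vcpus []*vCPU, pt *pageTables) {
	for _, c := range vcpus {
		if c.pcids != nil {
			c.pcids.drop(pt)
		}
	}
}

func main() {
	pt := &pageTables{id: 1}
	vcpus := []*vCPU{
		{pcids: &pcids{cache: map[*pageTables]uint16{pt: 7}}},
		{pcids: nil},
	}
	dropPageTables(vcpus, pt)
	fmt.Println("dropped without panic")
}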
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 0df8cfa0f..cd13390c3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -33,6 +33,7 @@ go_library(
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/safecopy",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 7b120a15d..bb0e03880 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -46,13 +46,13 @@ package ptrace
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 20244fd95..15dc46a5b 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -18,7 +18,6 @@ import (
"fmt"
"os"
"runtime"
- "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Linux kernel errnos which "should never be seen by user programs", but will
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index 2e6fbe488..245b20722 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -18,7 +18,6 @@
package ptrace
import (
- "sync"
"sync/atomic"
"syscall"
"unsafe"
@@ -26,6 +25,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sync"
)
// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index 3f094c2a7..86fd5ed58 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -17,7 +17,7 @@ package ring0
import (
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
)
// Kernel is a global kernel object.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 10dbd381f..9dae0dccb 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -18,6 +18,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
var (
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
index dc0eeec01..a850ce6cf 100644
--- a/pkg/sentry/platform/ring0/defs_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -18,6 +18,7 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
var (
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 64e9c0845..da07815ff 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -474,6 +474,16 @@ TEXT ·El1_sync(SB),NOSPLIT,$0
B el1_invalid
el1_da:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
B ·Halt(SB)
el1_ia:
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index e2e15ba5c..387a7f6c3 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -96,7 +96,10 @@ go_library(
"//pkg/sentry/platform/kvm:__subpackages__",
"//pkg/sentry/platform/ring0:__subpackages__",
],
- deps = ["//pkg/sentry/usermem"],
+ deps = [
+ "//pkg/sentry/usermem",
+ "//pkg/sync",
+ ],
)
go_test(
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
index 0f029f25d..e199bae18 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -17,7 +17,7 @@
package pagetables
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// limitPCID is the number of valid PCIDs.
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4301b697c..1684dfc24 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
}
// PackTOS packs an IP_TOS socket control message.
-func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte {
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
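
PackTOS now takes the TOS value as uint8, matching the unsigned 8-bit IP_TOS field; with the old int8 parameter, TOS bytes with the high bit set (the upper DSCP classes) would arrive as negative values. A two-line illustration of the difference:

package main

import "fmt"

func main() {
	var tos uint8 = 0xE0   // a CS7-class TOS byte; needs all 8 bits
	fmt.Println(tos)       // 224
	fmt.Println(int8(tos)) // -32: what a signed tos parameter would carry
}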
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 5eb06bbf4..b70047d81 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -14,6 +14,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/log",
"//pkg/sentry/kernel",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 9f87c32f1..a9cfc1749 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -35,6 +36,7 @@ const errorTargetName = "ERROR"
// metadata is opaque to netstack. It holds data that we need to translate
// between Linux's and netstack's iptables representations.
+// TODO(gvisor.dev/issue/170): This might be removable.
type metadata struct {
HookEntry [linux.NF_INET_NUMHOOKS]uint32
Underflow [linux.NF_INET_NUMHOOKS]uint32
@@ -51,7 +53,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG
}
// Find the appropriate table.
- table, err := findTable(ep, info.TableName())
+ table, err := findTable(ep, info.Name)
if err != nil {
return linux.IPTGetinfo{}, err
}
@@ -82,30 +84,31 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i
}
// Find the appropriate table.
- table, err := findTable(ep, userEntries.TableName())
+ table, err := findTable(ep, userEntries.Name)
if err != nil {
return linux.KernelIPTGetEntries{}, err
}
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ entries, _, err := convertNetstackToBinary(userEntries.Name.String(), table)
if err != nil {
return linux.KernelIPTGetEntries{}, err
}
if binary.Size(entries) > uintptr(outLen) {
+ log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
return entries, nil
}
-func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+func findTable(ep tcpip.Endpoint, tablename linux.TableName) (iptables.Table, *syserr.Error) {
ipt, err := ep.IPTables()
if err != nil {
return iptables.Table{}, syserr.FromError(err)
}
- table, ok := ipt.Tables[tableName]
+ table, ok := ipt.Tables[tablename.String()]
if !ok {
return iptables.Table{}, syserr.ErrInvalidArgument
}
@@ -135,110 +138,68 @@ func FillDefaultIPTables(stack *stack.Stack) {
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
// The table name has to fit in the struct.
- if linux.XT_TABLE_MAXNAMELEN < len(name) {
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ log.Warningf("Table name %q too long.", tablename)
return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
}
- copy(entries.Name[:], name)
-
- // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
- // these chains ends with an unconditional policy entry.
- for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
- chain, ok := table.BuiltinChains[hook]
- if !ok {
- // This table doesn't support this hook.
- continue
- }
-
- // Sanity check.
- if len(chain.Rules) < 1 {
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
- }
+ copy(entries.Name[:], tablename)
- for ruleIdx, rule := range chain.Rules {
- // If this is the first rule of a builtin chain, set
- // the metadata hook entry point.
- if ruleIdx == 0 {
+ for ruleIdx, rule := range table.Rules {
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
meta.HookEntry[hook] = entries.Size
}
-
- // Each rule corresponds to an entry.
- entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ meta.Underflow[underflow] = entries.Size
}
+ }
- for _, matcher := range rule.Matchers {
- // Serialize the matcher and add it to the
- // entry.
- serialized := marshalMatcher(matcher)
- entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
- }
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
- // Serialize and append the target.
- serialized := marshalTarget(rule.Target)
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
entry.Elems = append(entry.Elems, serialized...)
entry.NextOffset += uint16(len(serialized))
-
- // The underflow rule is the last rule in the chain,
- // and is an unconditional rule (i.e. it matches any
- // packet). This is enforced when saving iptables.
- if ruleIdx == len(chain.Rules)-1 {
- meta.Underflow[hook] = entries.Size
- }
-
- entries.Size += uint32(entry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, entry)
- meta.NumEntries++
+ entry.TargetOffset += uint16(len(serialized))
}
- }
-
- // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
- // these starts with an error node holding the chain's name and ends
- // with an unconditional return.
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
- // Lastly, each table ends with an unconditional error target rule as
- // its final entry.
- errorEntry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ meta.NumEntries++
}
- var errorTarget linux.XTErrorTarget
- errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
- copy(errorTarget.ErrorName[:], errorTargetName)
- copy(errorTarget.Target.Name[:], errorTargetName)
-
- // Serialize and add it to the list of entries.
- errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
- serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
- errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
- errorEntry.NextOffset += uint16(len(serializedErrorTarget))
-
- entries.Size += uint32(errorEntry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, errorEntry)
- meta.NumEntries++
- meta.Size = entries.Size
+ meta.Size = entries.Size
return entries, meta, nil
}
func marshalMatcher(matcher iptables.Matcher) []byte {
switch matcher.(type) {
default:
- // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
- // any call to marshalMatcher will panic.
+ // TODO(gvisor.dev/issue/170): We don't support any matchers
+ // yet, so any call to marshalMatcher will panic.
panic(fmt.Errorf("unknown matcher of type %T", matcher))
}
}
@@ -246,28 +207,46 @@ func marshalMatcher(matcher iptables.Matcher) []byte {
func marshalTarget(target iptables.Target) []byte {
switch target.(type) {
case iptables.UnconditionalAcceptTarget:
- return marshalUnconditionalAcceptTarget()
+ return marshalStandardTarget(iptables.Accept)
+ case iptables.UnconditionalDropTarget:
+ return marshalStandardTarget(iptables.Drop)
+ case iptables.ErrorTarget:
+ return marshalErrorTarget()
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
-func marshalUnconditionalAcceptTarget() []byte {
+func marshalStandardTarget(verdict iptables.Verdict) []byte {
// The target's name will be the empty string.
target := linux.XTStandardTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTStandardTarget,
},
- Verdict: translateStandardVerdict(iptables.Accept),
+ Verdict: translateFromStandardVerdict(verdict),
}
ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-// translateStandardVerdict translates verdicts the same way as the iptables
+func marshalErrorTarget() []byte {
+ // This is an error target named errorTargetName.
+ target := linux.XTErrorTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTErrorTarget,
+ },
+ }
+ copy(target.Name[:], errorTargetName)
+ copy(target.Target.Name[:], errorTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateStandardVerdict(verdict iptables.Verdict) int32 {
+func translateFromStandardVerdict(verdict iptables.Verdict) int32 {
switch verdict {
case iptables.Accept:
return -linux.NF_ACCEPT - 1
@@ -280,7 +259,258 @@ func translateStandardVerdict(verdict iptables.Verdict) int32 {
case iptables.Jump:
// TODO(gvisor.dev/issue/170): Support Jump.
panic("Jump isn't supported yet")
+ }
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+}
+
+// translateToStandardVerdict translates from the value in a
+// linux.XTStandardTarget to an iptables.Verdict.
+func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) {
+ // TODO(gvisor.dev/issue/170): Support other verdicts.
+ switch val {
+ case -linux.NF_ACCEPT - 1:
+ return iptables.Accept, nil
+ case -linux.NF_DROP - 1:
+ return iptables.Drop, nil
+ case -linux.NF_QUEUE - 1:
+ log.Warningf("Unsupported iptables verdict QUEUE.")
+ case linux.NF_RETURN:
+ log.Warningf("Unsupported iptables verdict RETURN.")
+ default:
+ log.Warningf("Unknown iptables verdict %d.", val)
+ }
+ return iptables.Invalid, syserr.ErrInvalidArgument
+}
+
+// SetEntries sets iptables rules for a single table. See
+// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
+func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+ printReplace(optVal)
+
+ // Get the basic rules data (struct ipt_replace).
+ if len(optVal) < linux.SizeOfIPTReplace {
+ log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+
+ // TODO(gvisor.dev/issue/170): Support other tables.
+ var table iptables.Table
+ switch replace.Name.String() {
+ case iptables.TablenameFilter:
+ table = iptables.EmptyFilterTable()
default:
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ return syserr.ErrInvalidArgument
+ }
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ var offsets []uint32
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ optVal = optVal[linux.SizeOfIPTEntry:]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ if entry.TargetOffset != linux.SizeOfIPTEntry {
+ // TODO(gvisor.dev/issue/170): Support matchers.
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support IPTIP
+ // filtering. We reject any nonzero IPTIP values for now.
+ emptyIPTIP := linux.IPTIP{}
+ if entry.IP != emptyIPTIP {
+ log.Warningf("netfilter: non-empty struct iptip found")
+ return syserr.ErrInvalidArgument
+ }
+
+ // Get the target of the rule.
+ target, consumed, err := parseTarget(optVal)
+ if err != nil {
+ return err
+ }
+ optVal = optVal[consumed:]
+
+ table.Rules = append(table.Rules, iptables.Rule{Target: target})
+ offsets = append(offsets, offset)
+ offset += linux.SizeOfIPTEntry + consumed
+ }
+
+ // Go through the list of supported hooks for this table and, for each
+ // one, set the rule it corresponds to.
+ for hook := range replace.HookEntry {
+ if table.ValidHooks()&uint32(hook) != 0 {
+ hk := hookFromLinux(hook)
+ for ruleIdx, offset := range offsets {
+ if offset == replace.HookEntry[hook] {
+ table.BuiltinChains[hk] = ruleIdx
+ }
+ if offset == replace.Underflow[hook] {
+ table.Underflows[hk] = ruleIdx
+ }
+ }
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ log.Warningf("Hook %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ log.Warningf("Underflow %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ ipt := stack.IPTables()
+ table.SetMetadata(metadata{
+ HookEntry: replace.HookEntry,
+ Underflow: replace.Underflow,
+ NumEntries: replace.NumEntries,
+ Size: replace.Size,
+ })
+ ipt.Tables[replace.Name.String()] = table
+ stack.SetIPTables(ipt)
+
+ return nil
+}
+
+// parseTarget parses a target from the start of optVal and returns the target
+// along with the number of bytes it occupies in optVal.
+func parseTarget(optVal []byte) (iptables.Target, uint32, *syserr.Error) {
+ if len(optVal) < linux.SizeOfXTEntryTarget {
+ log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var target linux.XTEntryTarget
+ buf := optVal[:linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ switch target.Name.String() {
+ case "":
+ // Standard target.
+ if len(optVal) < linux.SizeOfXTStandardTarget {
+ log.Warningf("netfilter.SetEntries: optVal has insufficient size for standard target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = optVal[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ verdict, err := translateToStandardVerdict(standardTarget.Verdict)
+ if err != nil {
+ return nil, 0, err
+ }
+ switch verdict {
+ case iptables.Accept:
+ return iptables.UnconditionalAcceptTarget{}, linux.SizeOfXTStandardTarget, nil
+ case iptables.Drop:
+ // TODO(gvisor.dev/issue/170): Return an
+ // iptables.UnconditionalDropTarget to support DROP.
+ log.Infof("netfilter DROP is not supported yet.")
+ return nil, 0, syserr.ErrInvalidArgument
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %v", verdict))
+ }
+
+ case errorTargetName:
+ // Error target.
+ if len(optVal) < linux.SizeOfXTErrorTarget {
+ log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal))
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = optVal[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named errorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch errorTarget.Name.String() {
+ case errorTargetName:
+ return iptables.ErrorTarget{}, linux.SizeOfXTErrorTarget, nil
+ default:
+ log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String())
+ return nil, 0, syserr.ErrInvalidArgument
+ }
+ }
+
+ // Unknown target.
+ log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
+ return nil, 0, syserr.ErrInvalidArgument
+}
+
+func hookFromLinux(hook int) iptables.Hook {
+ switch hook {
+ case linux.NF_INET_PRE_ROUTING:
+ return iptables.Prerouting
+ case linux.NF_INET_LOCAL_IN:
+ return iptables.Input
+ case linux.NF_INET_FORWARD:
+ return iptables.Forward
+ case linux.NF_INET_LOCAL_OUT:
+ return iptables.Output
+ case linux.NF_INET_POST_ROUTING:
+ return iptables.Postrouting
+ }
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
+}
+
+// printReplace prints information about the struct ipt_replace in optVal. It
+// is only for debugging.
+func printReplace(optVal []byte) {
+ // Basic replace info.
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+ log.Infof("Replacing table %q: %+v", replace.Name.String(), replace)
+
+ // Read in the list of entries at the end of replace.
+ var totalOffset uint16
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ var entry linux.IPTEntry
+ entryBuf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry)
+ log.Infof("Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry)
+
+ totalOffset += entry.NextOffset
+ if entry.TargetOffset == linux.SizeOfIPTEntry {
+ log.Infof("Entry has no matches.")
+ } else {
+ log.Infof("Entry has matches.")
+ }
+
+ var target linux.XTEntryTarget
+ targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(targetBuf, usermem.ByteOrder, &target)
+ log.Infof("Target named %q: %+v", target.Name.String(), target)
+
+ switch target.Name.String() {
+ case "":
+ var standardTarget linux.XTStandardTarget
+ stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget)
+ log.Infof("Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict)
+ case errorTargetName:
+ var errorTarget linux.XTErrorTarget
+ etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget)
+ log.Infof("Error target with name %q.", errorTarget.Name.String())
+ default:
+ log.Infof("Unknown target type.")
+ }
+
+ optVal = optVal[entry.NextOffset:]
}
}
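
A note on the verdict encoding these helpers translate: iptables carries standard verdicts as -(NF_constant + 1), so ACCEPT (NF_ACCEPT = 1) is stored as -2 and DROP (NF_DROP = 0) as -1, while non-negative values are used as jump offsets. A small round-trip sketch using those constant values:

package main

import "fmt"

const (
	nfDrop   = 0 // linux.NF_DROP
	nfAccept = 1 // linux.NF_ACCEPT
)

// toWire encodes a standard verdict the way iptables does: -(constant + 1).
func toWire(nf int32) int32 { return -nf - 1 }

// fromWire decodes it again; non-negative values would be jump offsets.
func fromWire(v int32) (int32, bool) {
	if v >= 0 {
		return 0, false
	}
	return -v - 1, true
}

func main() {
	for _, nf := range []int32{nfDrop, nfAccept} {
		wire := toWire(nf)
		dec, ok := fromWire(wire)
		fmt.Printf("NF=%d wire=%d decoded=%d standard=%v\n", nf, wire, dec, ok)
	}
}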
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 136821963..103933144 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -27,6 +27,7 @@ go_library(
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 463544c1a..2d9f4ba9b 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -8,6 +8,7 @@ go_library(
srcs = ["port.go"],
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
index e9d3275b1..2cd3afc22 100644
--- a/pkg/sentry/socket/netlink/port/port.go
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -24,7 +24,8 @@ import (
"fmt"
"math"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxPorts is a sanity limit on the maximum number of ports to allocate per
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index d2e3644a6..cea56f4ed 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -17,7 +17,6 @@ package netlink
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
@@ -34,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e414d8055..f78784569 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9e0d69046..d2f7e987d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,7 +29,6 @@ import (
"io"
"math"
"reflect"
- "sync"
"syscall"
"time"
@@ -49,6 +48,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -324,22 +324,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
// AF_INET6, and AF_PACKET addresses.
//
-// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
-//
// AddressAndFamily returns an address and its family.
-func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
- family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
- }
-
// Get the rest of the fields based on the address family.
- switch family {
+ switch family := usermem.ByteOrder.Uint16(addr); family {
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
@@ -638,10 +631,40 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
return r
}
+func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error {
+ if family == uint16(s.family) {
+ return nil
+ }
+ if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
+ v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ if !v {
+ return nil
+ }
+ }
+ return syserr.ErrInvalidArgument
+}
+
+// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
+// receiver's family is AF_INET6.
+//
+// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
+// represented by the empty string.
+//
+// TODO(gvisor.dev/issues/1556): remove this function.
+func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
+ if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
+ addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ }
+ return addr
+}
+
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -653,6 +676,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
return syserr.TranslateNetstackError(err)
}
+
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
+
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -681,10 +710,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
+ addr, family, err := AddressAndFamily(sockaddr)
if err != nil {
return err
}
+ if err := s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -985,13 +1018,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1225,11 +1268,11 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o uint32
+ var o int32
if v {
o = 1
}
- return int32(o), nil
+ return o, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1334,6 +1377,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return int32(v), nil
+ case linux.IP_RECVTOS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1367,6 +1425,26 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
return nil
}
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1438,7 +1516,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1819,6 +1910,13 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ case linux.IP_RECVTOS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1839,7 +1937,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
linux.IP_RECVORIGDSTADDR,
- linux.IP_RECVTOS,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2037,8 +2134,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
case linux.AF_INET6:
var out linux.SockAddrInet6
- if len(addr.Addr) == 4 {
- // Copy address is v4-mapped format.
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
copy(out.Addr[12:], addr.Addr)
out.Addr[10] = 0xff
out.Addr[11] = 0xff
@@ -2259,7 +2356,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
- return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}}
+ return socket.ControlMessages{
+ IP: tcpip.ControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ },
+ }
}
// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
@@ -2352,10 +2456,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
+ addrBuf, family, err := AddressAndFamily(to)
if err != nil {
return 0, err
}
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return 0, err
+ }
+ addrBuf = s.mapFamily(addrBuf, family)
addr = &addrBuf
}
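
The family check that AddressAndFamily used to perform now happens per socket: checkFamily accepts an exact family match always, and accepts AF_INET on an AF_INET6 socket only when the endpoint is not V6-only and the caller did not require an exact match (connect and sendmsg pass exact=false, bind passes exact=true); mapFamily then rewrites the IPv4 ANY address into the IPv4-mapped IPv6 ANY. A stripped-down sketch of that decision, with the netstack endpoint replaced by a plain bool for illustration:

package main

import "fmt"

const (
	afInet  = 2  // linux.AF_INET
	afInet6 = 10 // linux.AF_INET6
)

// v4MappedAny is the IPv4-mapped IPv6 ANY address (::ffff:0.0.0.0).
const v4MappedAny = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"

// checkFamily mirrors the new per-socket check: exact matches always pass;
// AF_INET on an AF_INET6 socket passes only for dual-stack (non-V6-only)
// endpoints when an exact match is not required.
func checkFamily(sockFamily, addrFamily uint16, v6Only, exact bool) bool {
	if addrFamily == sockFamily {
		return true
	}
	return !exact && addrFamily == afInet && sockFamily == afInet6 && !v6Only
}

// mapAny rewrites the empty (ANY) IPv4 address to the v4-mapped IPv6 ANY
// so a dual-stack endpoint sees the representation it expects.
func mapAny(addr string, sockFamily, addrFamily uint16) string {
	if addr == "" && sockFamily == afInet6 && addrFamily == afInet {
		return v4MappedAny
	}
	return addr
}

func main() {
	fmt.Println(checkFamily(afInet6, afInet, false, false)) // true: dual-stack connect
	fmt.Println(checkFamily(afInet6, afInet, true, false))  // false: V6-only socket
	fmt.Println(checkFamily(afInet6, afInet, false, true))  // false: bind needs an exact match
	fmt.Printf("%x\n", mapAny("", afInet6, afInet))         // 00000000000000000000ffff00000000
}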
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
deleted file mode 100644
index 4668b87d1..000000000
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ /dev/null
@@ -1,69 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "rpcinet",
- srcs = [
- "device.go",
- "rpcinet.go",
- "socket.go",
- "stack.go",
- "stack_unsafe.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- ":syscall_rpc_go_proto",
- "//pkg/abi/linux",
- "//pkg/binary",
- "//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/device",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/fsutil",
- "//pkg/sentry/inet",
- "//pkg/sentry/kernel",
- "//pkg/sentry/kernel/time",
- "//pkg/sentry/socket",
- "//pkg/sentry/socket/hostinet",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/sentry/socket/rpcinet/notifier",
- "//pkg/sentry/unimpl",
- "//pkg/sentry/usermem",
- "//pkg/syserr",
- "//pkg/syserror",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/tcpip/stack",
- "//pkg/unet",
- "//pkg/waiter",
- ],
-)
-
-proto_library(
- name = "syscall_rpc_proto",
- srcs = ["syscall_rpc.proto"],
- visibility = [
- "//visibility:public",
- ],
-)
-
-cc_proto_library(
- name = "syscall_rpc_cc_proto",
- visibility = [
- "//visibility:public",
- ],
- deps = [":syscall_rpc_proto"],
-)
-
-go_proto_library(
- name = "syscall_rpc_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
- proto = ":syscall_rpc_proto",
- visibility = [
- "//visibility:public",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD
deleted file mode 100644
index 23eadcb1b..000000000
--- a/pkg/sentry/socket/rpcinet/conn/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "conn",
- srcs = ["conn.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/binary",
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/syserr",
- "//pkg/unet",
- "@com_github_golang_protobuf//proto:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
deleted file mode 100644
index 356adad99..000000000
--- a/pkg/sentry/socket/rpcinet/conn/conn.go
+++ /dev/null
@@ -1,187 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package conn is an RPC connection to a syscall RPC server.
-package conn
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "syscall"
-
- "github.com/golang/protobuf/proto"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/unet"
-
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
-)
-
-type request struct {
- response []byte
- ready chan struct{}
- ignoreResult bool
-}
-
-// RPCConnection represents a single RPC connection to a syscall gofer.
-type RPCConnection struct {
- // reqID is the ID of the last request and must be accessed atomically.
- reqID uint64
-
- sendMu sync.Mutex
- socket *unet.Socket
-
- reqMu sync.Mutex
- requests map[uint64]request
-}
-
-// NewRPCConnection initializes an RPC connection to a socket gofer.
-func NewRPCConnection(s *unet.Socket) *RPCConnection {
- conn := &RPCConnection{socket: s, requests: map[uint64]request{}}
- go func() { // S/R-FIXME(b/77962828)
- var nums [16]byte
- for {
- for n := 0; n < len(nums); {
- nn, err := conn.socket.Read(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error reading length from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- b := make([]byte, binary.LittleEndian.Uint64(nums[:8]))
- id := binary.LittleEndian.Uint64(nums[8:])
-
- for n := 0; n < len(b); {
- nn, err := conn.socket.Read(b[n:])
- if err != nil {
- panic(fmt.Sprint("error reading request from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- conn.reqMu.Lock()
- r := conn.requests[id]
- if r.ignoreResult {
- delete(conn.requests, id)
- } else {
- r.response = b
- conn.requests[id] = r
- }
- conn.reqMu.Unlock()
- close(r.ready)
- }
- }()
- return conn
-}
-
-// NewRequest makes a request to the RPC gofer and returns the request ID and a
-// channel which will be closed once the request completes.
-func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) {
- b, err := proto.Marshal(&req)
- if err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- id := atomic.AddUint64(&c.reqID, 1)
- ch := make(chan struct{})
-
- c.reqMu.Lock()
- c.requests[id] = request{ready: ch, ignoreResult: ignoreResult}
- c.reqMu.Unlock()
-
- c.sendMu.Lock()
- defer c.sendMu.Unlock()
-
- var nums [16]byte
- binary.LittleEndian.PutUint64(nums[:8], uint64(len(b)))
- binary.LittleEndian.PutUint64(nums[8:], id)
- for n := 0; n < len(nums); {
- nn, err := c.socket.Write(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error writing length and ID to socket gofer: ", err))
- }
- n += nn
- }
-
- for n := 0; n < len(b); {
- nn, err := c.socket.Write(b[n:])
- if err != nil {
- panic(fmt.Sprint("error writing request to socket gofer: ", err))
- }
- n += nn
- }
-
- return id, ch
-}
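
A self-contained sketch of the wire framing NewRequest emits, shown here over an in-memory buffer: an 8-byte little-endian payload length, an 8-byte request ID, then the marshaled request bytes. io.ReadFull replaces the manual short-read/short-write loops of the original, and the payload is arbitrary bytes rather than a real protobuf.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// writeFrame emits the framing used above: an 8-byte little-endian payload
// length, an 8-byte request ID, then the payload itself.
func writeFrame(w io.Writer, id uint64, payload []byte) error {
	var hdr [16]byte
	binary.LittleEndian.PutUint64(hdr[:8], uint64(len(payload)))
	binary.LittleEndian.PutUint64(hdr[8:], id)
	if _, err := w.Write(hdr[:]); err != nil {
		return err
	}
	_, err := w.Write(payload)
	return err
}

// readFrame is the inverse, using io.ReadFull instead of manual read loops.
func readFrame(r io.Reader) (uint64, []byte, error) {
	var hdr [16]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return 0, nil, err
	}
	payload := make([]byte, binary.LittleEndian.Uint64(hdr[:8]))
	id := binary.LittleEndian.Uint64(hdr[8:])
	_, err := io.ReadFull(r, payload)
	return id, payload, err
}

func main() {
	var buf bytes.Buffer
	writeFrame(&buf, 7, []byte("request bytes"))
	id, payload, _ := readFrame(&buf)
	fmt.Println(id, string(payload)) // 7 request bytes
}
```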
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{
- Path: path,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result
- if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadFileResponse_Data).Data, nil
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (c *RPCConnection) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- req := &pb.SyscallRequest_WriteFile{&pb.WriteFileRequest{
- Path: path,
- Content: data,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_WriteFile).WriteFile
- if e := res.ErrorNumber; e != 0 {
- return int64(res.Written), syserr.FromHost(syscall.Errno(e))
- }
-
- return int64(res.Written), nil
-}
-
-// Request retrieves the response corresponding to the given request ID.
-//
-// The channel returned by NewRequest must have been closed before Request can
-// be called. This happens automatically; do not manually close the channel.
-func (c *RPCConnection) Request(id uint64) pb.SyscallResponse {
- c.reqMu.Lock()
- r := c.requests[id]
- delete(c.requests, id)
- c.reqMu.Unlock()
-
- var resp pb.SyscallResponse
- if err := proto.Unmarshal(r.response, &resp); err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- return resp
-}
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
deleted file mode 100644
index a3585e10d..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "notifier",
- srcs = ["notifier.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/waiter",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
deleted file mode 100644
index 7efe4301f..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package notifier implements an FD notifier implementation over RPC.
-package notifier
-
-import (
- "fmt"
- "sync"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-type fdInfo struct {
- queue *waiter.Queue
- waiting bool
-}
-
-// Notifier holds all the state necessary to issue notifications when IO events
-// occur in the observed FDs.
-type Notifier struct {
- // rpcConn is the connection that is used for sending RPCs.
- rpcConn *conn.RPCConnection
-
- // epFD is the epoll file descriptor used to register for io
- // notifications.
- epFD uint32
-
- // mu protects fdMap.
- mu sync.Mutex
-
- // fdMap maps file descriptors to their notification queues and waiting
- // status.
- fdMap map[uint32]*fdInfo
-}
-
-// NewRPCNotifier creates a new notifier object.
-func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) {
- id, c := cn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}, false /* ignoreResult */)
- <-c
-
- res := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result
- if e, ok := res.(*pb.EpollCreate1Response_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- w := &Notifier{
- rpcConn: cn,
- epFD: res.(*pb.EpollCreate1Response_Fd).Fd,
- fdMap: make(map[uint32]*fdInfo),
- }
-
- go w.waitAndNotify() // S/R-FIXME(b/77962828)
-
- return w, nil
-}
-
-// waitFD waits on mask for fd. The fdMap mutex must be held.
-func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
- if !fi.waiting && mask == 0 {
- return nil
- }
-
- e := pb.EpollEvent{
- Events: mask.ToLinux() | unix.EPOLLET,
- Fd: fd,
- }
-
- switch {
- case !fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
-
- fi.waiting = true
- case fi.waiting && mask == 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */)
- <-c
- n.rpcConn.Request(id)
-
- fi.waiting = false
- case fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
- }
-
- return nil
-}
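
The switch above amounts to a small edge-triggered registration state machine; a standalone sketch of just that decision (simplified constants, no real epoll calls) follows.

```go
package main

import "fmt"

// Simplified constants standing in for syscall.EPOLL_CTL_*.
const (
	ctlNone = iota
	ctlAdd
	ctlMod
	ctlDel
)

// epollOp reproduces the decision in waitFD above: register an FD the first
// time it has a non-empty mask, drop it when the mask becomes empty, and
// modify it otherwise. The bool result reports the new "waiting" state.
func epollOp(waiting bool, mask uint32) (op int, nowWaiting bool) {
	switch {
	case !waiting && mask != 0:
		return ctlAdd, true
	case waiting && mask == 0:
		return ctlDel, false
	case waiting && mask != 0:
		return ctlMod, true
	}
	return ctlNone, false
}

func main() {
	fmt.Println(epollOp(false, 0x1)) // 1 true  (ADD)
	fmt.Println(epollOp(true, 0))    // 3 false (DEL)
	fmt.Println(epollOp(true, 0x4))  // 2 true  (MOD)
}
```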
-
-// addFD adds an FD to the list of FDs observed by n.
-func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- // Panic if we're already notifying on this FD.
- if _, ok := n.fdMap[fd]; ok {
- panic(fmt.Sprintf("File descriptor %d added twice", fd))
- }
-
- // We have nothing to wait for at the moment. Just add it to the map.
- n.fdMap[fd] = &fdInfo{queue: queue}
-}
-
-// updateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) updateFD(fd uint32) error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- if fi, ok := n.fdMap[fd]; ok {
- return n.waitFD(fd, fi, fi.queue.Events())
- }
-
- return nil
-}
-
-// removeFD removes an FD from the list of FDs observed by n.
-func (n *Notifier) removeFD(fd uint32) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
-	// Remove from the epoll object, then from the map.
- n.waitFD(fd, n.fdMap[fd], 0)
- delete(n.fdMap, fd)
-}
-
-// hasFD returns true if the FD is in the list of observed FDs.
-func (n *Notifier) hasFD(fd uint32) bool {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- _, ok := n.fdMap[fd]
- return ok
-}
-
-// waitAndNotify loops waiting for I/O event notifications from the epoll
-// object. Once notifications arrive, they are dispatched to the
-// registered queue.
-func (n *Notifier) waitAndNotify() error {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result
- if e, ok := res.(*pb.EpollWaitResponse_ErrorNumber); ok {
- err := syscall.Errno(e.ErrorNumber)
-		// NOTE(magi): I don't think epoll_wait can return EAGAIN, but I'm being
-		// conservatively careful here since exiting the notification thread
-		// would be really bad.
- if err == syscall.EINTR || err == syscall.EAGAIN {
- continue
- }
- return err
- }
-
- n.mu.Lock()
- for _, e := range res.(*pb.EpollWaitResponse_Events).Events.Events {
- if fi, ok := n.fdMap[e.Fd]; ok {
- fi.queue.Notify(waiter.EventMaskFromLinux(e.Events))
- }
- }
- n.mu.Unlock()
- }
-}
-
-// AddFD adds an FD to the list of observed FDs.
-func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error {
- n.addFD(fd, queue)
- return nil
-}
-
-// UpdateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) UpdateFD(fd uint32) error {
- return n.updateFD(fd)
-}
-
-// RemoveFD removes an FD from the list of observed FDs.
-func (n *Notifier) RemoveFD(fd uint32) {
- n.removeFD(fd)
-}
-
-// HasFD returns true if the FD is in the list of observed FDs.
-//
-// This should only be used by tests to assert that FDs are correctly
-// registered.
-func (n *Notifier) HasFD(fd uint32) bool {
- return n.hasFD(fd)
-}
-
-// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just
-// to query the FD's current state; this method will block on the RPC response
-// although the syscall is non-blocking.
-func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: mask.ToLinux()}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result
- if e, ok := res.(*pb.PollResponse_ErrorNumber); ok {
- if syscall.Errno(e.ErrorNumber) == syscall.EINTR {
- continue
- }
- return mask
- }
-
- return waiter.EventMaskFromLinux(res.(*pb.PollResponse_Events).Events)
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
deleted file mode 100644
index ddb76d9d4..000000000
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ /dev/null
@@ -1,909 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "sync/atomic"
- "syscall"
- "time"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// socketOperations implements fs.FileOperations and socket.Socket for a socket
-// implemented using a host socket.
-type socketOperations struct {
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- socket.SendReceiveTimeout
-
- family int // Read-only.
- stype linux.SockType // Read-only.
- protocol int // Read-only.
-
- fd uint32 // must be O_NONBLOCK
- wq *waiter.Queue
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-
-	// shState is the state of the connection with respect to shutdown. Because
-	// we're mixing non-blocking semantics on the other side, we have to adapt for
-	// some strange differences between blocking and non-blocking sockets.
- shState int32
-}
-
-// Verify that we actually implement socket.Socket.
-var _ = socket.Socket(&socketOperations{})
-
-// newSocketFile creates a new RPC socket.
-func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- fd := res.(*pb.SocketResponse_Fd).Fd
-
- var wq waiter.Queue
- stack.notifier.AddFD(fd, &wq)
-
- dirent := socket.NewDirent(ctx, socketDevice)
- defer dirent.DecRef()
- return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
- family: family,
- stype: skType,
- protocol: protocol,
- wq: &wq,
- fd: fd,
- rpcConn: stack.rpcConn,
- notifier: stack.notifier,
- }), nil
-}
-
-func isBlockingErrno(err error) bool {
- return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
-}
-
-func translateIOSyscallError(err error) error {
- if isBlockingErrno(err) {
- return syserror.ErrWouldBlock
- }
- return err
-}
-
-// setShutdownFlags will set the shutdown flag so we can handle blocking reads
-// after a read shutdown.
-func (s *socketOperations) setShutdownFlags(how int) {
- var f tcpip.ShutdownFlags
- switch how {
- case linux.SHUT_RD:
- f = tcpip.ShutdownRead
- case linux.SHUT_WR:
- f = tcpip.ShutdownWrite
- case linux.SHUT_RDWR:
- f = tcpip.ShutdownWrite | tcpip.ShutdownRead
- }
-
- // Atomically update the flags.
- for {
- old := atomic.LoadInt32(&s.shState)
- if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) {
- break
- }
- }
-}
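
The loop above is the classic compare-and-swap way to OR bits into an int32 atomically; a minimal standalone version is sketched below (newer Go releases also provide dedicated atomic Or helpers, but the CAS loop is the portable form this code relies on).

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// orInt32 sets bits in *addr atomically via a compare-and-swap loop, the
// same pattern setShutdownFlags uses above.
func orInt32(addr *int32, bits int32) {
	for {
		old := atomic.LoadInt32(addr)
		if atomic.CompareAndSwapInt32(addr, old, old|bits) {
			return
		}
	}
}

func main() {
	var state int32
	orInt32(&state, 1)
	orInt32(&state, 2)
	fmt.Println(state) // 3
}
```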
-
-func (s *socketOperations) resetShutdownFlags() {
- atomic.StoreInt32(&s.shState, 0)
-}
-
-func (s *socketOperations) isShutRdSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0
-}
-
-func (s *socketOperations) isShutWrSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0
-}
-
-// Release implements fs.FileOperations.Release.
-func (s *socketOperations) Release() {
- s.notifier.RemoveFD(s.fd)
-
- // We always need to close the FD.
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}, true /* ignoreResult */)
-}
-
-// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return s.notifier.NonBlockingPoll(s.fd, mask)
-}
-
-// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- s.wq.EventRegister(e, mask)
- s.notifier.UpdateFD(s.fd)
-}
-
-// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOperations) EventUnregister(e *waiter.Entry) {
- s.wq.EventUnregister(e)
- s.notifier.UpdateFD(s.fd)
-}
-
-func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result
- if e, ok := res.(*pb.ReadResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadResponse_Data), nil
-}
-
-// Read implements fs.FileOperations.Read.
-func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- req := &pb.SyscallRequest_Read{&pb.ReadRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- }}
-
- res, se := rpcRead(ctx.(*kernel.Task), req)
- if se == nil {
- n, e := dst.CopyOut(ctx, res.Data)
- return int64(n), e
- }
-
- return 0, se.ToError()
-}
-
-func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result
- if e, ok := res.(*pb.WriteResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.WriteResponse_Length).Length, nil
-}
-
-// Write implements fs.FileOperations.Write.
-func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- t := ctx.(*kernel.Task)
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, err
- }
-
- n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}})
- if n > 0 && n < uint32(src.NumBytes()) {
- // The FileOperations.Write interface expects us to return ErrWouldBlock in
- // the event of a partial write.
- return int64(n), syserror.ErrWouldBlock
- }
- return int64(n), err.ToError()
-}
-
-func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Connect implements socket.Socket.Connect.
-func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- if !blocking {
- e := rpcConnect(t, s.fd, sockaddr)
- if e == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return e
- }
-
- // Register for notification when the endpoint becomes writable, then
- // initiate the connection.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut|waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
- for {
- if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress {
- if err == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return err
- }
-
- // It's pending, so we have to wait for a notification, and fetch the
- // result once the wait completes.
- if err := t.Block(ch); err != nil {
- return syserr.FromError(err)
- }
- }
-}
-
-func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result
- if e, ok := res.(*pb.AcceptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.AcceptResponse_Payload).Payload, nil
-}
-
-// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- payload, se := rpcAccept(t, s.fd, peerRequested)
-
- // Check if we need to block.
- if blocking && se == syserr.ErrTryAgain {
- // Register for notifications.
- e, ch := waiter.NewChannelEntry(nil)
- // FIXME(b/119878986): This waiter.EventHUp is a partial
- // measure, need to figure out how to translate linux events to
- // internal events.
- s.EventRegister(&e, waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
-
- // Try to accept the connection again; if it fails, then wait until we
- // get a notification.
- for {
- if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrTryAgain {
- break
- }
-
- if err := t.Block(ch); err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- }
- }
-
- // Handle any error from accept.
- if se != nil {
- return 0, nil, 0, se
- }
-
- var wq waiter.Queue
- s.notifier.AddFD(payload.Fd, &wq)
-
- dirent := socket.NewDirent(t, socketDevice)
- defer dirent.DecRef()
- fileFlags := fs.FileFlags{
- Read: true,
- Write: true,
- NonSeekable: true,
- NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
- }
- file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
- family: s.family,
- stype: s.stype,
- protocol: s.protocol,
- wq: &wq,
- fd: payload.Fd,
- rpcConn: s.rpcConn,
- notifier: s.notifier,
- })
- defer file.DecRef()
-
- fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
- CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- })
- if err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- t.Kernel().RecordSocket(file)
-
- if peerRequested {
- return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
- }
-
- return fd, nil, 0, nil
-}
-
-// Bind implements socket.Socket.Bind.
-func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Listen implements socket.Socket.Listen.
-func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
-	// We save the shutdown state because of strange differences on Linux
-	// related to recvs on blocking vs. non-blocking sockets after a SHUT_RD.
-	// We need to emulate that behavior on the blocking side.
- // TODO(b/120096741): There is a possible race that can exist with loopback,
- // where data could possibly be lost.
- s.setShutdownFlags(how)
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
-
- return nil
-}
-
-// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
- // within the sentry.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.RecvTimeout()), nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.SendTimeout()), nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result
- if e, ok := res.(*pb.GetSockOptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.GetSockOptResponse_Opt).Opt, nil
-}
-
-// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
-	// Because blocking actually happens within the sentry, we need to inspect
-	// this socket option to determine whether it's SO_RCVTIMEO or SO_SNDTIMEO,
-	// and if so, save it and use it as the deadline for recv(2)- or
-	// send(2)-related syscalls.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetRecvTimeout(v.ToNsecCapped())
- return nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetSendTimeout(v.ToNsecCapped())
- return nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
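
A small sketch of the timeout handling above: the struct timeval supplied to SO_RCVTIMEO/SO_SNDTIMEO must have a microsecond field in [0, 1s), otherwise the code returns a domain error, and it is then converted to nanoseconds for use as a sentry-side deadline. The overflow capping done by ToNsecCapped is omitted; names are illustrative.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// timevalToNsec mirrors the SO_RCVTIMEO/SO_SNDTIMEO validation above: the
// microsecond field must be in [0, 1s), and the result is the timeout in
// nanoseconds.
func timevalToNsec(sec, usec int64) (int64, error) {
	if usec < 0 || usec >= int64(time.Second/time.Microsecond) {
		return 0, errors.New("tv_usec out of range (domain error)")
	}
	return sec*int64(time.Second) + usec*int64(time.Microsecond), nil
}

func main() {
	fmt.Println(timevalToNsec(2, 500000))  // 2500000000 <nil>
	fmt.Println(timevalToNsec(0, 2000000)) // rejected
}
```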
-
-// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result
- if e, ok := res.(*pb.GetPeerNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetPeerNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result
- if e, ok := res.(*pb.GetSockNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetSockNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
- stack := t.NetworkContext().(*Stack)
-
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Ioctl{&pb.IOCtlRequest{Fd: fd, Cmd: cmd, Arg: arg}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Ioctl).Ioctl.Result
- if e, ok := res.(*pb.IOCtlResponse_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- return res.(*pb.IOCtlResponse_Value).Value, nil
-}
-
-// ifconfIoctlFromStack populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctlFromStack(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
- // If Ptr is NULL, return the necessary buffer size via Len.
- // Otherwise, write up to Len bytes starting at Ptr containing ifreq
- // structs.
- t := ctx.(*kernel.Task)
- s := t.NetworkContext().(*Stack)
- if s == nil {
- return syserr.ErrNoDevice.ToError()
- }
-
- if ifc.Ptr == 0 {
- ifc.Len = int32(len(s.Interfaces())) * int32(linux.SizeOfIFReq)
- return nil
- }
-
- max := ifc.Len
- ifc.Len = 0
- for key, ifaceAddrs := range s.InterfaceAddrs() {
- iface := s.Interfaces()[key]
- for _, ifaceAddr := range ifaceAddrs {
- // Don't write past the end of the buffer.
- if ifc.Len+int32(linux.SizeOfIFReq) > max {
- break
- }
- if ifaceAddr.Family != linux.AF_INET {
- continue
- }
-
- // Populate ifr.ifr_addr.
- ifr := linux.IFReq{}
- ifr.SetName(iface.Name)
- usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
- usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
- copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
-
- // Copy the ifr to userspace.
- dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
- ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-// Ioctl implements fs.FileOperations.Ioctl.
-func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- t := ctx.(*kernel.Task)
-
- cmd := uint32(args[1].Int())
- arg := args[2].Pointer()
-
- var buf []byte
- switch cmd {
- // The following ioctls take 4 byte argument parameters.
- case syscall.TIOCINQ,
- syscall.TIOCOUTQ:
- buf = make([]byte, 4)
- // The following ioctls have args which are sizeof(struct ifreq).
- case syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
- buf = make([]byte, linux.SizeOfIFReq)
- case syscall.SIOCGIFCONF:
- // SIOCGIFCONF has slightly different behavior than the others, in that it
- // will need to populate the array of ifreqs.
- var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return 0, err
- }
-
- if err := ifconfIoctlFromStack(ctx, io, &ifc); err != nil {
- return 0, err
- }
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- return 0, err
-
- case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
- unimpl.EmitUnimplementedEvent(ctx)
-
- default:
- return 0, syserror.ENOTTY
- }
-
- _, err := io.CopyIn(ctx, arg, buf, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- if err != nil {
- return 0, err
- }
-
- v, err := rpcIoctl(t, s.fd, cmd, buf)
- if err != nil {
- return 0, err
- }
-
- if len(v) != len(buf) {
- return 0, syserror.EINVAL
- }
-
- _, err = io.CopyOut(ctx, arg, v, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
-}
-
-func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-// extractControlMessages searches the control messages for SO_TIMESTAMP, the
-// only control message we support, and sets it if present; all other control
-// messages are ignored.
-func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_ResultPayload) socket.ControlMessages {
- c := socket.ControlMessages{}
- if len(payload.GetCmsgData()) > 0 {
- // Parse the control messages looking for SO_TIMESTAMP.
- msgs, e := syscall.ParseSocketControlMessage(payload.GetCmsgData())
- if e != nil {
- return socket.ControlMessages{}
- }
- for _, m := range msgs {
- if m.Header.Level != linux.SOL_SOCKET || m.Header.Type != linux.SO_TIMESTAMP {
- continue
- }
-
- // Let's parse the time stamp and set it.
- if len(m.Data) < linux.SizeOfTimeval {
- // Give up on locating the SO_TIMESTAMP option.
- return socket.ControlMessages{}
- }
-
- var v linux.Timeval
- binary.Unmarshal(m.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- c.IP.HasTimestamp = true
- c.IP.Timestamp = v.ToNsecCapped()
- break
- }
- }
- return c
-}
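
A hedged sketch of the timestamp decoding step: on 64-bit little-endian Linux (the layout assumed here), the SO_TIMESTAMP payload is a struct timeval of two int64 fields, which the code above converts to nanoseconds. The helper below works on raw bytes so it can run standalone.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// decodeTimeval decodes an SO_TIMESTAMP payload as laid out on 64-bit
// little-endian Linux: int64 tv_sec followed by int64 tv_usec, combined
// into nanoseconds. The width and byte order are platform assumptions
// made for this sketch.
func decodeTimeval(data []byte) (nsec int64, ok bool) {
	if len(data) < 16 {
		return 0, false
	}
	sec := int64(binary.LittleEndian.Uint64(data[0:8]))
	usec := int64(binary.LittleEndian.Uint64(data[8:16]))
	return sec*1e9 + usec*1e3, true
}

func main() {
	var buf [16]byte
	binary.LittleEndian.PutUint64(buf[0:8], 1)       // 1 second
	binary.LittleEndian.PutUint64(buf[8:16], 250000) // 250 ms
	fmt.Println(decodeTimeval(buf[:])) // 1250000000 true
}
```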
-
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- Sender: senderRequested,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- CmsgLength: uint32(controlDataLen),
- }}
-
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
-	// We'll have to block. Register for notifications and keep trying to
-	// receive the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventIn)
- defer s.EventUnregister(&e)
-
- for {
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
- if s.isShutRdSet() {
-			// Blocking would have caused us to block indefinitely, so we return 0;
-			// this is the same behavior as Linux.
- return 0, 0, nil, 0, socket.ControlMessages{}, nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
- }
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
- }
- }
-}
-
-func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- // Reject Unix control messages.
- if !controlMessages.Unix.Empty() {
- return 0, syserr.ErrInvalidArgument
- }
-
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
- // TODO(bgeffon): this needs to change to map directly to a SendMsg syscall
- // in the RPC.
- totalWritten := 0
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return int(n), err
- }
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
- }
-
- // We'll have to block. Register for notification and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut)
- defer s.EventUnregister(&e)
-
- for {
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
-
- if err == nil && totalWritten < int(src.NumBytes()) {
- continue
- }
- }
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- // We eat the error in this situation.
- return int(totalWritten), nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return int(totalWritten), syserr.ErrTryAgain
- }
- return int(totalWritten), syserr.FromError(err)
- }
- }
-}
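
The send loop above follows a common pattern: attempt the write, trim whatever was accepted from the front of the buffer, and only block and retry while the remaining data would block. A simplified standalone version (ignoring the deadline handling and the case where the original deliberately swallows a late error) is sketched below with stand-in callbacks.

```go
package main

import (
	"errors"
	"fmt"
)

var errWouldBlock = errors.New("would block")

// sendAll trims accepted bytes from the front of the buffer after every
// attempt and retries only while the remaining data "would block".
// waitWritable stands in for registering on the waiter queue and blocking.
func sendAll(send func([]byte) (int, error), waitWritable func() error, data []byte) (int, error) {
	total := 0
	for len(data) > 0 {
		n, err := send(data)
		total += n
		data = data[n:]
		if err == nil {
			continue
		}
		if !errors.Is(err, errWouldBlock) {
			return total, err
		}
		if werr := waitWritable(); werr != nil {
			return total, werr
		}
	}
	return total, nil
}

func main() {
	calls := 0
	send := func(b []byte) (int, error) {
		calls++
		if calls == 1 {
			return 3, errWouldBlock // partial write, then blocked
		}
		return len(b), nil
	}
	n, err := sendAll(send, func() error { return nil }, []byte("hello world"))
	fmt.Println(n, err) // 11 <nil>
}
```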
-
-// State implements socket.Socket.State.
-func (s *socketOperations) State() uint32 {
- // TODO(b/127845868): Define a new rpc to query the socket state.
- return 0
-}
-
-// Type implements socket.Socket.Type.
-func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
- return s.family, s.stype, s.protocol
-}
-
-type socketProvider struct {
- family int
-}
-
-// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // Check that we are using the RPC network stack.
- stack := t.NetworkContext()
- if stack == nil {
- return nil, nil
- }
-
- s, ok := stack.(*Stack)
- if !ok {
- return nil, nil
- }
-
- // Only accept TCP and UDP.
- //
- // Try to restrict the flags we will accept to minimize backwards
- // incompatibility with netstack.
- stype := stypeflags & linux.SOCK_TYPE_MASK
- switch stype {
- case syscall.SOCK_STREAM:
- switch protocol {
- case 0, syscall.IPPROTO_TCP:
- // ok
- default:
- return nil, nil
- }
- case syscall.SOCK_DGRAM:
- switch protocol {
- case 0, syscall.IPPROTO_UDP:
- // ok
- default:
- return nil, nil
- }
- default:
- return nil, nil
- }
-
- return newSocketFile(t, s, p.family, stype, protocol)
-}
-
-// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
- // Not supported by AF_INET/AF_INET6.
- return nil, nil, nil
-}
-
-func init() {
- for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
- socket.RegisterProvider(family, &socketProvider{family})
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
deleted file mode 100644
index f7878a760..000000000
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "fmt"
- "syscall"
-
- "gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/unet"
-)
-
-// Stack implements inet.Stack for RPC backed sockets.
-type Stack struct {
- interfaces map[int32]inet.Interface
- interfaceAddrs map[int32][]inet.InterfaceAddr
- routes []inet.Route
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-}
-
-// NewStack returns a Stack containing the current state of the host network
-// stack.
-func NewStack(fd int32) (*Stack, error) {
- sock, err := unet.NewSocket(int(fd))
- if err != nil {
- return nil, err
- }
-
- stack := &Stack{
- interfaces: make(map[int32]inet.Interface),
- interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
- rpcConn: conn.NewRPCConnection(sock),
- }
-
- var e error
- stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn)
- if e != nil {
- return nil, e
- }
-
- links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETLINK failed: %v", err)
- }
-
- addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETADDR failed: %v", err)
- }
-
- e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs)
- if e != nil {
- return nil, e
- }
-
- routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err)
- }
-
- stack.routes, e = hostinet.ExtractHostRoutes(routes)
- if e != nil {
- return nil, e
- }
-
- return stack, nil
-}
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (s *Stack) RPCReadFile(path string) ([]byte, *syserr.Error) {
- return s.rpcConn.RPCReadFile(path)
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- return s.rpcConn.RPCWriteFile(path, data)
-}
-
-// Interfaces implements inet.Stack.Interfaces.
-func (s *Stack) Interfaces() map[int32]inet.Interface {
- interfaces := make(map[int32]inet.Interface)
- for k, v := range s.interfaces {
- interfaces[k] = v
- }
- return interfaces
-}
-
-// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
-func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- addrs := make(map[int32][]inet.InterfaceAddr)
- for k, v := range s.interfaceAddrs {
- addrs[k] = append([]inet.InterfaceAddr(nil), v...)
- }
- return addrs
-}
-
-// SupportsIPv6 implements inet.Stack.SupportsIPv6.
-func (s *Stack) SupportsIPv6() bool {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
-func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
-func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
-func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
-func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
-func (s *Stack) TCPSACKEnabled() (bool, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// Statistics implements inet.Stack.Statistics.
-func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserr.ErrEndpointOperation.ToError()
-}
-
-// RouteTable implements inet.Stack.RouteTable.
-func (s *Stack) RouteTable() []inet.Route {
- return append([]inet.Route(nil), s.routes...)
-}
-
-// Resume implements inet.Stack.Resume.
-func (s *Stack) Resume() {}
-
-// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
-func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
-
-// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
-func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
-
-// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
-func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go
deleted file mode 100644
index a94bdad83..000000000
--- a/pkg/sentry/socket/rpcinet/stack_unsafe.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
-)
-
-// newNetlinkRouteRequest builds a netlink message for getting the RIB,
-// the routing information base.
-func newNetlinkRouteRequest(proto, seq, family int) []byte {
- rr := &syscall.NetlinkRouteRequest{}
- rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg)
- rr.Header.Type = uint16(proto)
- rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST
- rr.Header.Seq = uint32(seq)
- rr.Data.Family = uint8(family)
- return netlinkRRtoWireFormat(rr)
-}
-
-func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte {
- b := make([]byte, rr.Header.Len)
- *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
- *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
- *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
- *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
- *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
- b[16] = byte(rr.Data.Family)
- return b
-}
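
The serialization above can also be expressed without unsafe: a 16-byte nlmsghdr (length, type, flags, sequence, pid) followed by the single rtgenmsg family byte, written with encoding/binary. Netlink uses host byte order; little-endian is assumed in this sketch, matching the platforms this code targets, and the constants used in main are the standard Linux netlink values.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// buildRouteRequest lays out the same bytes as netlinkRRtoWireFormat above,
// but with encoding/binary instead of unsafe pointer stores.
func buildRouteRequest(msgType, flags uint16, seq uint32, family uint8) []byte {
	const hdrLen = 16 // sizeof(nlmsghdr)
	b := make([]byte, hdrLen+1)
	binary.LittleEndian.PutUint32(b[0:4], uint32(len(b)))
	binary.LittleEndian.PutUint16(b[4:6], msgType)
	binary.LittleEndian.PutUint16(b[6:8], flags)
	binary.LittleEndian.PutUint32(b[8:12], seq)
	binary.LittleEndian.PutUint32(b[12:16], 0) // pid: the kernel fills this in
	b[16] = family
	return b
}

func main() {
	// RTM_GETLINK (18) with NLM_F_REQUEST|NLM_F_DUMP (0x301), AF_UNSPEC.
	fmt.Printf("% x\n", buildRouteRequest(18, 0x301, 1, 0))
}
```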
-
-func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(syscall.AF_NETLINK), Type: int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK), Protocol: int64(syscall.NETLINK_ROUTE)}}}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.SocketResponse_Fd).Fd, nil
-}
-
-func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-func (s *Stack) closeNetlinkFd(fd uint32) {
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}}, true /* ignoreResult */)
-}
-
-func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: fd,
- Data: buf,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }}
-
- n, err := s.rpcSendMsg(req)
- return int(n), err
-}
-
-func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: fd,
- Length: l,
- Sender: false,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- }}
-
- res, err := s.rpcRecvMsg(req)
- if err != nil {
- return nil, err
- }
- return res.Data, nil
-}
-
-func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) {
- fd, err := s.getNetlinkFd()
- if err != nil {
- return nil, err.ToError()
- }
- defer s.closeNetlinkFd(fd)
-
- lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
- b := binary.Marshal(nil, usermem.ByteOrder, &lsa)
- if err := s.bindNetlinkFd(fd, b); err != nil {
- return nil, err.ToError()
- }
-
- wb := newNetlinkRouteRequest(proto, 1, family)
- _, err = s.sendMsg(fd, wb, b, 0)
- if err != nil {
- return nil, err.ToError()
- }
-
- var tab []byte
-done:
- for {
- rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0)
- nr := len(rb)
- if err != nil {
- return nil, err.ToError()
- }
-
- if nr < syscall.NLMSG_HDRLEN {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
-
- tab = append(tab, rb...)
- msgs, e := syscall.ParseNetlinkMessage(rb)
- if e != nil {
- return nil, e
- }
-
- for _, m := range msgs {
- if m.Header.Type == syscall.NLMSG_DONE {
- break done
- }
- if m.Header.Type == syscall.NLMSG_ERROR {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
- }
- }
-
- return tab, nil
-}
-
-// DoNetlinkRouteRequest returns the routing information base (RIB), which
-// consists of network facility information, states, and parameters.
-func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
- data, err := s.netlinkRequest(req, syscall.AF_UNSPEC)
- if err != nil {
- return nil, err
- }
- return syscall.ParseNetlinkMessage(data)
-}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
deleted file mode 100644
index b677e9eb3..000000000
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ /dev/null
@@ -1,352 +0,0 @@
-syntax = "proto3";
-
-// Package syscall_rpc defines a set of networking-related system calls that
-// can be forwarded to a socket gofer.
-//
-package syscall_rpc;
-
-message SendmsgRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
- bytes address = 3;
- bool more = 4;
- bool end_of_record = 5;
-}
-
-message SendmsgResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
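
For readers unfamiliar with how a proto3 oneof surfaces in Go: each branch becomes a small wrapper type implementing a shared single-method interface, and callers (like the rpcSendMsg helpers earlier in this change) unpack it with a type switch or assertion. A hand-rolled, self-contained analogue is sketched below; the real generated types live in syscall_rpc_go_proto.

```go
package main

import "fmt"

// isSendmsgResult and the two wrapper structs mimic the shape of the code
// protoc generates for the "oneof result" above.
type isSendmsgResult interface{ isSendmsgResult() }

type resultErrorNumber struct{ ErrorNumber uint32 }
type resultLength struct{ Length uint32 }

func (resultErrorNumber) isSendmsgResult() {}
func (resultLength) isSendmsgResult()      {}

// handle unpacks the oneof with a type switch, the same pattern the socket
// code uses on *pb.SendmsgResponse_ErrorNumber / *pb.SendmsgResponse_Length.
func handle(res isSendmsgResult) {
	switch r := res.(type) {
	case resultErrorNumber:
		fmt.Println("errno:", r.ErrorNumber)
	case resultLength:
		fmt.Println("bytes sent:", r.Length)
	}
}

func main() {
	handle(resultErrorNumber{ErrorNumber: 11}) // e.g. EAGAIN
	handle(resultLength{Length: 512})
}
```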
-
-message IOCtlRequest {
- uint32 fd = 1;
- uint32 cmd = 2;
- bytes arg = 3;
-}
-
-message IOCtlResponse {
- oneof result {
- uint32 error_number = 1;
- bytes value = 2;
- }
-}
-
-message RecvmsgRequest {
- uint32 fd = 1;
- uint32 length = 2;
- bool sender = 3;
- bool peek = 4;
- bool trunc = 5;
- uint32 cmsg_length = 6;
-}
-
-message OpenRequest {
- bytes path = 1;
- uint32 flags = 2;
- uint32 mode = 3;
-}
-
-message OpenResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message ReadRequest {
- uint32 fd = 1;
- uint32 length = 2;
-}
-
-message ReadResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message ReadFileRequest {
- string path = 1;
-}
-
-message ReadFileResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message WriteRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
-}
-
-message WriteResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
-
-message WriteFileRequest {
- string path = 1;
- bytes content = 2;
-}
-
-message WriteFileResponse {
- uint32 error_number = 1;
- uint32 written = 2;
-}
-
-message AddressResponse {
- bytes address = 1;
- uint32 length = 2;
-}
-
-message RecvmsgResponse {
- message ResultPayload {
- bytes data = 1 [ctype = CORD];
- AddressResponse address = 2;
- uint32 length = 3;
- bytes cmsg_data = 4;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message BindRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message BindResponse {
- uint32 error_number = 1;
-}
-
-message AcceptRequest {
- uint32 fd = 1;
- bool peer = 2;
- int64 flags = 3;
-}
-
-message AcceptResponse {
- message ResultPayload {
- uint32 fd = 1;
- AddressResponse address = 2;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message ConnectRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message ConnectResponse {
- uint32 error_number = 1;
-}
-
-message ListenRequest {
- uint32 fd = 1;
- int64 backlog = 2;
-}
-
-message ListenResponse {
- uint32 error_number = 1;
-}
-
-message ShutdownRequest {
- uint32 fd = 1;
- int64 how = 2;
-}
-
-message ShutdownResponse {
- uint32 error_number = 1;
-}
-
-message CloseRequest {
- uint32 fd = 1;
-}
-
-message CloseResponse {
- uint32 error_number = 1;
-}
-
-message GetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- uint32 length = 4;
-}
-
-message GetSockOptResponse {
- oneof result {
- uint32 error_number = 1;
- bytes opt = 2;
- }
-}
-
-message SetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- bytes opt = 4;
-}
-
-message SetSockOptResponse {
- uint32 error_number = 1;
-}
-
-message GetSockNameRequest {
- uint32 fd = 1;
-}
-
-message GetSockNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message GetPeerNameRequest {
- uint32 fd = 1;
-}
-
-message GetPeerNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message SocketRequest {
- int64 family = 1;
- int64 type = 2;
- int64 protocol = 3;
-}
-
-message SocketResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message EpollWaitRequest {
- uint32 fd = 1;
- uint32 num_events = 2;
- sint64 msec = 3;
-}
-
-message EpollEvent {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message EpollEvents {
- repeated EpollEvent events = 1;
-}
-
-message EpollWaitResponse {
- oneof result {
- uint32 error_number = 1;
- EpollEvents events = 2;
- }
-}
-
-message EpollCtlRequest {
- uint32 epfd = 1;
- int64 op = 2;
- uint32 fd = 3;
- EpollEvent event = 4;
-}
-
-message EpollCtlResponse {
- uint32 error_number = 1;
-}
-
-message EpollCreate1Request {
- int64 flag = 1;
-}
-
-message EpollCreate1Response {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message PollRequest {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message PollResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 events = 2;
- }
-}
-
-message SyscallRequest {
- oneof args {
- SocketRequest socket = 1;
- SendmsgRequest sendmsg = 2;
- RecvmsgRequest recvmsg = 3;
- BindRequest bind = 4;
- AcceptRequest accept = 5;
- ConnectRequest connect = 6;
- ListenRequest listen = 7;
- ShutdownRequest shutdown = 8;
- CloseRequest close = 9;
- GetSockOptRequest get_sock_opt = 10;
- SetSockOptRequest set_sock_opt = 11;
- GetSockNameRequest get_sock_name = 12;
- GetPeerNameRequest get_peer_name = 13;
- EpollWaitRequest epoll_wait = 14;
- EpollCtlRequest epoll_ctl = 15;
- EpollCreate1Request epoll_create1 = 16;
- PollRequest poll = 17;
- ReadRequest read = 18;
- WriteRequest write = 19;
- OpenRequest open = 20;
- IOCtlRequest ioctl = 21;
- WriteFileRequest write_file = 22;
- ReadFileRequest read_file = 23;
- }
-}
-
-message SyscallResponse {
- oneof result {
- SocketResponse socket = 1;
- SendmsgResponse sendmsg = 2;
- RecvmsgResponse recvmsg = 3;
- BindResponse bind = 4;
- AcceptResponse accept = 5;
- ConnectResponse connect = 6;
- ListenResponse listen = 7;
- ShutdownResponse shutdown = 8;
- CloseResponse close = 9;
- GetSockOptResponse get_sock_opt = 10;
- SetSockOptResponse set_sock_opt = 11;
- GetSockNameResponse get_sock_name = 12;
- GetPeerNameResponse get_peer_name = 13;
- EpollWaitResponse epoll_wait = 14;
- EpollCtlResponse epoll_ctl = 15;
- EpollCreate1Response epoll_create1 = 16;
- PollResponse poll = 17;
- ReadResponse read = 18;
- WriteResponse write = 19;
- OpenResponse open = 20;
- IOCtlResponse ioctl = 21;
- WriteFileResponse write_file = 22;
- ReadFileResponse read_file = 23;
- }
-}
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 788ad70d2..d7ba95dff 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -32,6 +32,7 @@ go_library(
"//pkg/ilist",
"//pkg/refs",
"//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index dea11e253..9e6fbc111 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -15,10 +15,9 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e27b1c714..5dcd3d95e 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -15,9 +15,8 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 37c7ac3c1..fcc0da332 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,11 +16,11 @@
package transport
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 91effe89a..7f49ba864 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,13 +116,16 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, family, err := netstack.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
}
return "", err
}
+ if family != linux.AF_UNIX {
+ return "", syserr.ErrInvalidArgument
+ }
// The address is trimmed by GetAddress.
p := string(addr.Addr)
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index d46421199..aa1ac720c 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -10,7 +10,8 @@ go_library(
"capability.go",
"clone.go",
"futex.go",
- "linux64.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
"open.go",
"poll.go",
"ptrace.go",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64_amd64.go
index e603f858f..1e823b685 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package strace
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
// types for display / formatting.
var linuxAMD64 = SyscallMap{
@@ -365,3 +372,13 @@ var linuxAMD64 = SyscallMap{
434: makeSyscallInfo("pidfd_open", Hex, Hex),
435: makeSyscallInfo("clone3", Hex, Hex),
}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+ )
+}
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
new file mode 100644
index 000000000..c3ac5248d
--- /dev/null
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -0,0 +1,323 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxARM64 provides a mapping of the Linux arm64 syscalls and their argument
+// types for display / formatting.
+var linuxARM64 = SyscallMap{
+ 0: makeSyscallInfo("io_setup", Hex, Hex),
+ 1: makeSyscallInfo("io_destroy", Hex),
+ 2: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 3: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 4: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 5: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 6: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 7: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 8: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 9: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 10: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 11: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 12: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 13: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 14: makeSyscallInfo("removexattr", Path, Path),
+ 15: makeSyscallInfo("lremovexattr", Path, Path),
+ 16: makeSyscallInfo("fremovexattr", FD, Path),
+ 17: makeSyscallInfo("getcwd", PostPath, Hex),
+ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 19: makeSyscallInfo("eventfd2", Hex, Hex),
+ 20: makeSyscallInfo("epoll_create1", Hex),
+ 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 23: makeSyscallInfo("dup", FD),
+ 24: makeSyscallInfo("dup3", FD, FD, Hex),
+ 25: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 26: makeSyscallInfo("inotify_init1", Hex),
+ 27: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 28: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 29: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 30: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 31: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 32: makeSyscallInfo("flock", FD, Hex),
+ 33: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 34: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 35: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 36: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 37: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 38: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 39: makeSyscallInfo("umount2", Path, Hex),
+ 40: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 41: makeSyscallInfo("pivot_root", Path, Path),
+ 42: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ 43: makeSyscallInfo("statfs", Path, Hex),
+ 44: makeSyscallInfo("fstatfs", FD, Hex),
+ 45: makeSyscallInfo("truncate", Path, Hex),
+ 46: makeSyscallInfo("ftruncate", FD, Hex),
+ 47: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 48: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 49: makeSyscallInfo("chdir", Path),
+ 50: makeSyscallInfo("fchdir", FD),
+ 51: makeSyscallInfo("chroot", Path),
+ 52: makeSyscallInfo("fchmod", FD, Mode),
+ 53: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 54: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 55: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 56: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 57: makeSyscallInfo("close", FD),
+ 58: makeSyscallInfo("vhangup"),
+ 59: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 60: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 61: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 62: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 63: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 64: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 65: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 66: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 67: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 68: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 69: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 70: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 71: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 72: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 73: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 74: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 75: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 76: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 77: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 78: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 79: makeSyscallInfo("fstatat", FD, Path, Stat, Hex),
+ 80: makeSyscallInfo("fstat", FD, Stat),
+ 81: makeSyscallInfo("sync"),
+ 82: makeSyscallInfo("fsync", FD),
+ 83: makeSyscallInfo("fdatasync", FD),
+ 84: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 85: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 86: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 87: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 88: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 89: makeSyscallInfo("acct", Hex),
+ 90: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 91: makeSyscallInfo("capset", CapHeader, CapData),
+ 92: makeSyscallInfo("personality", Hex),
+ 93: makeSyscallInfo("exit", Hex),
+ 94: makeSyscallInfo("exit_group", Hex),
+ 95: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 96: makeSyscallInfo("set_tid_address", Hex),
+ 97: makeSyscallInfo("unshare", CloneFlags),
+ 98: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 99: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 100: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 101: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 102: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 103: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 104: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 105: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 106: makeSyscallInfo("delete_module", Hex, Hex),
+ 107: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 108: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 109: makeSyscallInfo("timer_getoverrun", Hex),
+ 110: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 111: makeSyscallInfo("timer_delete", Hex),
+ 112: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 113: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 114: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 115: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 116: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 117: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 118: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 119: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 120: makeSyscallInfo("sched_getscheduler", Hex),
+ 121: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 122: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 123: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 124: makeSyscallInfo("sched_yield"),
+ 125: makeSyscallInfo("sched_get_priority_max", Hex),
+ 126: makeSyscallInfo("sched_get_priority_min", Hex),
+ 127: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 128: makeSyscallInfo("restart_syscall"),
+ 129: makeSyscallInfo("kill", Hex, Signal),
+ 130: makeSyscallInfo("tkill", Hex, Signal),
+ 131: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 132: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 133: makeSyscallInfo("rt_sigsuspend", Hex),
+ 134: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction),
+ 135: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 136: makeSyscallInfo("rt_sigpending", Hex),
+ 137: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 138: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 139: makeSyscallInfo("rt_sigreturn"),
+ 140: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 141: makeSyscallInfo("getpriority", Hex, Hex),
+ 142: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 143: makeSyscallInfo("setregid", Hex, Hex),
+ 144: makeSyscallInfo("setgid", Hex),
+ 145: makeSyscallInfo("setreuid", Hex, Hex),
+ 146: makeSyscallInfo("setuid", Hex),
+ 147: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 148: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 149: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 150: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 151: makeSyscallInfo("setfsuid", Hex),
+ 152: makeSyscallInfo("setfsgid", Hex),
+ 153: makeSyscallInfo("times", Hex),
+ 154: makeSyscallInfo("setpgid", Hex, Hex),
+ 155: makeSyscallInfo("getpgid", Hex),
+ 156: makeSyscallInfo("getsid", Hex),
+ 157: makeSyscallInfo("setsid"),
+ 158: makeSyscallInfo("getgroups", Hex, Hex),
+ 159: makeSyscallInfo("setgroups", Hex, Hex),
+ 160: makeSyscallInfo("uname", Uname),
+ 161: makeSyscallInfo("sethostname", Hex, Hex),
+ 162: makeSyscallInfo("setdomainname", Hex, Hex),
+ 163: makeSyscallInfo("getrlimit", Hex, Hex),
+ 164: makeSyscallInfo("setrlimit", Hex, Hex),
+ 165: makeSyscallInfo("getrusage", Hex, Rusage),
+ 166: makeSyscallInfo("umask", Hex),
+ 167: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 168: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 169: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 170: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 171: makeSyscallInfo("adjtimex", Hex),
+ 172: makeSyscallInfo("getpid"),
+ 173: makeSyscallInfo("getppid"),
+ 174: makeSyscallInfo("getuid"),
+ 175: makeSyscallInfo("geteuid"),
+ 176: makeSyscallInfo("getgid"),
+ 177: makeSyscallInfo("getegid"),
+ 178: makeSyscallInfo("gettid"),
+ 179: makeSyscallInfo("sysinfo", Hex),
+ 180: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 181: makeSyscallInfo("mq_unlink", Hex),
+ 182: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 183: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 184: makeSyscallInfo("mq_notify", Hex, Hex),
+ 185: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 186: makeSyscallInfo("msgget", Hex, Hex),
+ 187: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 188: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 189: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 190: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 191: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 192: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 193: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 194: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 195: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 196: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 197: makeSyscallInfo("shmdt", Hex),
+ 198: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 199: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 200: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 201: makeSyscallInfo("listen", FD, Hex),
+ 202: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 203: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 204: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 205: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 206: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 207: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 208: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 209: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 210: makeSyscallInfo("shutdown", FD, Hex),
+ 211: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 212: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 213: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 214: makeSyscallInfo("brk", Hex),
+ 215: makeSyscallInfo("munmap", Hex, Hex),
+ 216: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 218: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 224: makeSyscallInfo("swapon", Hex, Hex),
+ 225: makeSyscallInfo("swapoff", Hex),
+ 226: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 227: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 228: makeSyscallInfo("mlock", Hex, Hex),
+ 229: makeSyscallInfo("munlock", Hex, Hex),
+ 230: makeSyscallInfo("mlockall", Hex),
+ 231: makeSyscallInfo("munlockall"),
+ 232: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 233: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 234: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 235: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 236: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 237: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 238: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 239: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 241: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 242: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 243: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+
+ 260: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 261: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 262: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 263: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 264: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 265: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 266: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 267: makeSyscallInfo("syncfs", FD),
+ 268: makeSyscallInfo("setns", FD, Hex),
+ 269: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 270: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 271: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 272: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 273: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 274: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 275: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 276: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 277: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 278: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 279: makeSyscallInfo("memfd_create", Path, Hex),
+ 280: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 281: makeSyscallInfo("execveat", FD, Path, Hex, Hex, Hex),
+ 282: makeSyscallInfo("userfaultfd", Hex),
+ 283: makeSyscallInfo("membarrier", Hex),
+ 284: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 285: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 287: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 291: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 292: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 293: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.ARM64,
+ syscalls: linuxARM64})
+}
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 51f2efb39..b6d7177f4 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */)
+ fa, _, err := netstack.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index e5d486c4e..24e29a2ba 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -250,14 +250,7 @@ type syscallTable struct {
syscalls SyscallMap
}
-// syscallTables contains all syscall tables.
-var syscallTables = []syscallTable{
- {
- os: abi.Linux,
- arch: arch.AMD64,
- syscalls: linuxAMD64,
- },
-}
+var syscallTables []syscallTable
// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
// must not be changed.
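Editor's note: with the static table removed, each architecture's file now contributes its table from init() (see linux64_amd64.go and linux64_arm64.go above), and Lookup consults the registered entries. A minimal sketch of that lookup shape, assuming only the syscallTable fields shown in this hunk and the abi/arch imports already used by the per-arch files (illustrative only, not the actual implementation in syscalls.go):

func lookupTable(os abi.OS, a arch.Arch) (SyscallMap, bool) {
	// Scan the tables appended by the per-arch init() functions.
	for _, table := range syscallTables {
		if table.os == os && table.arch == a {
			return table.syscalls, true
		}
	}
	return nil, false
}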
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index a76975cee..430d796ba 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -13,6 +13,8 @@ go_library(
"sigset.go",
"sys_aio.go",
"sys_capability.go",
+ "sys_clone_amd64.go",
+ "sys_clone_arm64.go",
"sys_epoll.go",
"sys_eventfd.go",
"sys_file.go",
@@ -91,6 +93,7 @@ go_library(
"//pkg/sentry/syscalls",
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 1d9018c96..60469549d 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -16,13 +16,13 @@ package linux
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
index 479c5f6ff..6b2920900 100644
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ b/pkg/sentry/syscalls/linux/linux64_amd64.go
@@ -228,10 +228,10 @@ var AMD64 = &kernel.SyscallTable{
185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
187: syscalls.Supported("readahead", Readahead),
- 188: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
+ 188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 191: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
+ 191: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
index d3f61f5e8..8c1b20911 100644
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ b/pkg/sentry/syscalls/linux/linux64_arm64.go
@@ -41,10 +41,10 @@ var ARM64 = &kernel.SyscallTable{
2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 5: syscalls.PartiallySupported("setxattr", Setxattr, "Only supported for tmpfs.", nil),
+ 5: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 8: syscalls.PartiallySupported("getxattr", Getxattr, "Only supported for tmpfs.", nil),
+ 8: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_clone_amd64.go b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
new file mode 100644
index 000000000..dd43cf18d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has several flavors. We implement the default one from Linux 3.11
+// on x86_64:
+// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ childTID := args[3].Pointer()
+ tls := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_clone_arm64.go b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
new file mode 100644
index 000000000..cf68a8949
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has several flavors. We implement the default one from Linux 3.11
+// on arm64 (kernel/fork.c with CONFIG_CLONE_BACKWARDS defined in the config file):
+// sys_clone(clone_flags, newsp, parent_tidptr, tls_val, child_tidptr)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ tls := args[3].Pointer()
+ childTID := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
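Editor's note: the two wrappers differ only in which argument slot carries the TLS value versus the child TID pointer. A minimal sketch of that mapping, assuming the arch and usermem packages already used elsewhere in this package (illustrative only; cloneTIDAndTLS is a hypothetical helper, not part of the change):

// cloneTIDAndTLS illustrates the per-arch raw argument layout of clone(2).
func cloneTIDAndTLS(onARM64 bool, args arch.SyscallArguments) (childTID, tls usermem.Addr) {
	if onARM64 {
		// arm64: sys_clone(clone_flags, newsp, parent_tidptr, tls_val, child_tidptr)
		return args[4].Pointer(), args[3].Pointer()
	}
	// amd64: sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
	return args[3].Pointer(), args[4].Pointer()
}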
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 4b5aafcc0..cda517a81 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -41,7 +41,7 @@ const maxListenBacklog = 1024
const maxAddrLen = 200
// maxOptLen is the maximum sockopt parameter length we're willing to accept.
-const maxOptLen = 1024
+const maxOptLen = 1024 * 8
// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
// willing to accept. Note that this limit is smaller than Linux, which allows
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 4115116ff..b47c3b5c4 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -220,19 +220,6 @@ func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr
return uintptr(ntid), ctrl, err
}
-// Clone implements linux syscall clone(2).
-// sys_clone has so many flavors. We implement the default one in linux 3.11
-// x86_64:
-// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
-func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- flags := int(args[0].Int())
- stack := args[1].Pointer()
- parentTID := args[2].Pointer()
- childTID := args[3].Pointer()
- tls := args[4].Pointer()
- return clone(t, flags, stack, parentTID, childTID, tls)
-}
-
// Fork implements Linux syscall fork(2).
func Fork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// "A call to fork() is equivalent to a call to clone(2) specifying flags
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 97d9a65ea..23d20da6f 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,12 +25,12 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-// Getxattr implements linux syscall getxattr(2).
-func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// GetXattr implements linux syscall getxattr(2).
+func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pathAddr := args[0].Pointer()
nameAddr := args[1].Pointer()
valueAddr := args[2].Pointer()
- size := args[3].SizeT()
+ size := uint64(args[3].SizeT())
path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
if err != nil {
@@ -39,22 +39,28 @@ func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
valueLen := 0
err = fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
- value, err := getxattr(t, d, dirPath, nameAddr)
+ // If getxattr(2) is called with size 0, it succeeds and returns the size of
+ // the attribute value even when that size is nonzero. In that case we still
+ // need to retrieve the entire attribute value so we can report the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+
+ value, err := getXattr(t, d, dirPath, nameAddr, uint64(requestedSize))
if err != nil {
return err
}
valueLen = len(value)
- if size == 0 {
- return nil
- }
- if size > linux.XATTR_SIZE_MAX {
- size = linux.XATTR_SIZE_MAX
- }
- if valueLen > int(size) {
+ if uint64(valueLen) > requestedSize {
return syserror.ERANGE
}
+ // Skip copying out the attribute value if size is 0.
+ if size == 0 {
+ return nil
+ }
_, err = t.CopyOutBytes(valueAddr, []byte(value))
return err
})
@@ -64,8 +70,8 @@ func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(valueLen), nil, nil
}
-// getxattr implements getxattr from the given *fs.Dirent.
-func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr) (string, error) {
+// getXattr implements getxattr(2) from the given *fs.Dirent.
+func getXattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr, size uint64) (string, error) {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return "", syserror.ENOTDIR
}
@@ -83,15 +89,15 @@ func getxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr usermem.Addr)
return "", syserror.EOPNOTSUPP
}
- return d.Inode.Getxattr(name)
+ return d.Inode.GetXattr(t, name, size)
}
-// Setxattr implements linux syscall setxattr(2).
-func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+// SetXattr implements linux syscall setxattr(2).
+func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
pathAddr := args[0].Pointer()
nameAddr := args[1].Pointer()
valueAddr := args[2].Pointer()
- size := args[3].SizeT()
+ size := uint64(args[3].SizeT())
flags := args[4].Uint()
path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
@@ -104,12 +110,12 @@ func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
- return setxattr(t, d, dirPath, nameAddr, valueAddr, size, flags)
+ return setXattr(t, d, dirPath, nameAddr, valueAddr, uint64(size), flags)
})
}
-// setxattr implements setxattr from the given *fs.Dirent.
-func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint, flags uint32) error {
+// setXattr implements setxattr(2) from the given *fs.Dirent.
+func setXattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error {
if dirPath && !fs.IsDir(d.Inode.StableAttr) {
return syserror.ENOTDIR
}
@@ -136,7 +142,7 @@ func setxattr(t *kernel.Task, d *fs.Dirent, dirPath bool, nameAddr, valueAddr us
return syserror.EOPNOTSUPP
}
- return d.Inode.Setxattr(name, value)
+ return d.Inode.SetXattr(t, d, name, value, flags)
}
func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
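Editor's note: the size-0 behaviour handled above mirrors the usual userspace probe pattern, where getxattr(2) is first called with an empty buffer to learn the value's length and then called again with a buffer of that size. A minimal sketch using golang.org/x/sys/unix (illustrative only; the path and attribute name are made up):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func readXattr(path, name string) ([]byte, error) {
	// Probe: with a nil buffer, Getxattr reports the attribute's size.
	sz, err := unix.Getxattr(path, name, nil)
	if err != nil {
		return nil, err
	}
	buf := make([]byte, sz)
	// Fetch: this can still fail with ERANGE if the value grew in between.
	n, err := unix.Getxattr(path, name, buf)
	if err != nil {
		return nil, err
	}
	return buf[:n], nil
}

func main() {
	val, err := readXattr("/tmp/example", "user.comment")
	fmt.Println(string(val), err)
}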
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 18e212dff..3cde3a0be 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "Parameters",
},
@@ -36,7 +36,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/metric",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
index 318503277..f9a93115d 100644
--- a/pkg/sentry/time/calibrated_clock.go
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -17,11 +17,11 @@
package time
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index c32fe3241..5518ac3d0 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -18,5 +18,6 @@ go_library(
deps = [
"//pkg/bits",
"//pkg/memutil",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index d6ef644d8..538c645eb 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -17,12 +17,12 @@ package usage
import (
"fmt"
"os"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// MemoryKind represents a type of memory used by the application.
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 4c6aa04a1..35c7be259 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -34,7 +34,7 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
"//pkg/sentry/usermem",
- "//pkg/syncutil",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
@@ -54,6 +54,7 @@ go_test(
"//pkg/sentry/context/contexttest",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 1bc9c4a38..486a76475 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -16,9 +16,9 @@ package vfs
import (
"fmt"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 66eb57bc2..c00b3c84b 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -17,13 +17,13 @@ package vfs
import (
"bytes"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index adff0b94b..3b933468d 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -17,8 +17,9 @@ package vfs
import (
"fmt"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestMountTableLookupEmpty(t *testing.T) {
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index ab13fa461..bd90d36c4 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,7 +26,7 @@ import (
"sync/atomic"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// mountKey represents the location at which a Mount is mounted. It is
@@ -75,7 +75,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq syncutil.SeqCount
+ seq sync.SeqCount
seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
index 8e155654f..cf80df90e 100644
--- a/pkg/sentry/vfs/pathname.go
+++ b/pkg/sentry/vfs/pathname.go
@@ -15,10 +15,9 @@
package vfs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f1edb0680..d279d05ca 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -30,6 +30,26 @@ const (
MayExec = 1
)
+// OnlyRead returns true if access _only_ allows read.
+func (a AccessTypes) OnlyRead() bool {
+ return a == MayRead
+}
+
+// MayRead returns true if access allows read.
+func (a AccessTypes) MayRead() bool {
+ return a&MayRead != 0
+}
+
+// MayWrite returns true if access allows write.
+func (a AccessTypes) MayWrite() bool {
+ return a&MayWrite != 0
+}
+
+// MayExec returns true if access allows exec.
+func (a AccessTypes) MayExec() bool {
+ return a&MayExec != 0
+}
+
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
@@ -53,7 +73,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && (ats&MayWrite == 0)) || ats == MayRead {
+ if (isDir && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -61,7 +81,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || (ats&MayExec == 0) || (mode&0111 != 0) {
+ if isDir || !ats.MayExec() || (mode&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
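Editor's note: the new helpers let call sites spell out intent instead of masking bits, as in the two conversions above. A minimal sketch of their use outside the package, assuming an import of gvisor.dev/gvisor/pkg/sentry/vfs (illustrative only; describeAccess is a hypothetical helper, not part of the change):

// describeAccess lists the rights requested in ats using the new helpers.
func describeAccess(ats vfs.AccessTypes) []string {
	var rights []string
	if ats.MayRead() {
		rights = append(rights, "read")
	}
	if ats.MayWrite() {
		rights = append(rights, "write")
	}
	if ats.MayExec() {
		rights = append(rights, "exec")
	}
	return rights
}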
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index f0641d314..8a0b382f6 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -16,11 +16,11 @@ package vfs
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index ea2db7031..1f21b0b31 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -29,12 +29,12 @@ package vfs
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD
index 4d8435265..28f21f13d 100644
--- a/pkg/sentry/watchdog/BUILD
+++ b/pkg/sentry/watchdog/BUILD
@@ -13,5 +13,6 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 5e4611333..bfb2fac26 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -32,7 +32,6 @@ package watchdog
import (
"bytes"
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -40,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Opts configures the watchdog.
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index 130806c86..af47e2ba1 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -376,6 +376,37 @@ func TestRace(t *testing.T) {
}
}
+// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
+// the sleeper and that the wakers are retrieved in the order asserted.
+func TestRaceInOrder(t *testing.T) {
+ const wakers = 100
+ const wakeRequests = 10000
+
+ w := make([]Waker, wakers)
+ s := Sleeper{}
+
+ // Associate each waker, then start a goroutine that asserts them in order.
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+ go func() {
+ n := 0
+ for n < wakeRequests {
+ wk := w[n%len(w)]
+ wk.Assert()
+ n++
+ }
+ }()
+
+ // Wait for all wake up notifications from all wakers.
+ for i := 0; i < wakeRequests; i++ {
+ v, _ := s.Fetch(true)
+ if got, want := v, i%wakers; got != want {
+ t.Fatalf("got %d want %d", got, want)
+ }
+ }
+}
+
// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up
// from 4 wakers when at least one is already asserted.
func BenchmarkSleeperMultiSelect(b *testing.B) {
diff --git a/pkg/syncutil/BUILD b/pkg/sync/BUILD
index cb1f41628..e8cd16b8f 100644
--- a/pkg/syncutil/BUILD
+++ b/pkg/sync/BUILD
@@ -29,8 +29,9 @@ go_template(
)
go_library(
- name = "syncutil",
+ name = "sync",
srcs = [
+ "aliases.go",
"downgradable_rwmutex_unsafe.go",
"memmove_unsafe.go",
"norace_unsafe.go",
@@ -38,15 +39,15 @@ go_library(
"seqcount.go",
"syncutil.go",
],
- importpath = "gvisor.dev/gvisor/pkg/syncutil",
+ importpath = "gvisor.dev/gvisor/pkg/sync",
)
go_test(
- name = "syncutil_test",
+ name = "sync_test",
size = "small",
srcs = [
"downgradable_rwmutex_test.go",
"seqcount_test.go",
],
- embed = [":syncutil"],
+ embed = [":sync"],
)
diff --git a/pkg/syncutil/LICENSE b/pkg/sync/LICENSE
index 6a66aea5e..6a66aea5e 100644
--- a/pkg/syncutil/LICENSE
+++ b/pkg/sync/LICENSE
diff --git a/pkg/syncutil/README.md b/pkg/sync/README.md
index 2183c4e20..2183c4e20 100644
--- a/pkg/syncutil/README.md
+++ b/pkg/sync/README.md
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
new file mode 100644
index 000000000..20c7ca041
--- /dev/null
+++ b/pkg/sync/aliases.go
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+)
+
+// Aliases of standard library types.
+type (
+ // Mutex is an alias of sync.Mutex.
+ Mutex = sync.Mutex
+
+ // RWMutex is an alias of sync.RWMutex.
+ RWMutex = sync.RWMutex
+
+ // Cond is an alias of sync.Cond.
+ Cond = sync.Cond
+
+ // Locker is an alias of sync.Locker.
+ Locker = sync.Locker
+
+ // Once is an alias of sync.Once.
+ Once = sync.Once
+
+ // Pool is an alias of sync.Pool.
+ Pool = sync.Pool
+
+ // WaitGroup is an alias of sync.WaitGroup.
+ WaitGroup = sync.WaitGroup
+
+ // Map is an alias of sync.Map.
+ Map = sync.Map
+)
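Editor's note: because these are type aliases, existing call sites keep compiling unchanged once their import path is switched, which is exactly what the per-file edits elsewhere in this change do. A minimal sketch of a caller after the rename (illustrative only):

package example

import "gvisor.dev/gvisor/pkg/sync" // replaces the standard library "sync" import

type counter struct {
	mu sync.Mutex // alias of the standard sync.Mutex
	n  uint64
}

func (c *counter) inc() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.n++
}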
diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go
index 525c4beed..525c4beed 100644
--- a/pkg/syncutil/atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr_unsafe.go
diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD
index 63f411a90..418eda29c 100644
--- a/pkg/syncutil/atomicptrtest/BUILD
+++ b/pkg/sync/atomicptrtest/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "atomicptr_int_unsafe.go",
package = "atomicptr",
suffix = "Int",
- template = "//pkg/syncutil:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "int",
},
@@ -18,7 +18,7 @@ go_template_instance(
go_library(
name = "atomicptr",
srcs = ["atomicptr_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr",
+ importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr",
)
go_test(
diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go
index 8fdc5112e..8fdc5112e 100644
--- a/pkg/syncutil/atomicptrtest/atomicptr_test.go
+++ b/pkg/sync/atomicptrtest/atomicptr_test.go
diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go
index ffaf7ecc7..f04496bc5 100644
--- a/pkg/syncutil/downgradable_rwmutex_test.go
+++ b/pkg/sync/downgradable_rwmutex_test.go
@@ -9,7 +9,7 @@
// addition of downgradingWriter and the renaming of num_iterations to
// numIterations to shut up Golint.
-package syncutil
+package sync
import (
"fmt"
diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go
index 51e11555d..9bb55cd3a 100644
--- a/pkg/syncutil/downgradable_rwmutex_unsafe.go
+++ b/pkg/sync/downgradable_rwmutex_unsafe.go
@@ -16,7 +16,7 @@
// - RUnlock -> Lock (via writerSem)
// - DowngradeLock -> RLock (via readerSem)
-package syncutil
+package sync
import (
"sync"
diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
index 348675baa..ad4a3a37e 100644
--- a/pkg/syncutil/memmove_unsafe.go
+++ b/pkg/sync/memmove_unsafe.go
@@ -8,7 +8,7 @@
// Check go:linkname function signatures when updating Go version.
-package syncutil
+package sync
import (
"unsafe"
diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index 0a0a9deda..006055dd6 100644
--- a/pkg/syncutil/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -5,7 +5,7 @@
// +build !race
-package syncutil
+package sync
import (
"unsafe"
diff --git a/pkg/syncutil/race_unsafe.go b/pkg/sync/race_unsafe.go
index 206067ec1..31d8fa9a6 100644
--- a/pkg/syncutil/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -5,7 +5,7 @@
// +build race
-package syncutil
+package sync
import (
"runtime"
diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go
index cb6d2eb22..eda6fb131 100644
--- a/pkg/syncutil/seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic_unsafe.go
@@ -13,7 +13,7 @@ import (
"strings"
"unsafe"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
@@ -26,17 +26,17 @@ type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
// with any writer critical sections in sc.
-func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value {
+func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value {
// This function doesn't use SeqAtomicTryLoad because doing so is
// measurably, significantly (~20%) slower; Go is awful at inlining.
var val Value
for {
epoch := sc.BeginRead()
- if syncutil.RaceEnabled {
+ if sync.RaceEnabled {
// runtime.RaceDisable() doesn't actually stop the race detector,
// so it can't help us here. Instead, call runtime.memmove
// directly, which is not instrumented by the race detector.
- syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
// This is ~40% faster for short reads than going through memmove.
val = *ptr
@@ -52,10 +52,10 @@ func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value {
// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
// would race with a writer critical section, SeqAtomicTryLoad returns
// (unspecified, false).
-func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) {
+func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) {
var val Value
- if syncutil.RaceEnabled {
- syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ if sync.RaceEnabled {
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
val = *ptr
}
@@ -66,7 +66,7 @@ func init() {
var val Value
typ := reflect.TypeOf(val)
name := typ.Name()
- if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 {
+ if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
}
}
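Editor's note: SeqAtomicLoad is the retry loop of a sequence-counter read. A minimal sketch of that protocol for a pointer-free type, assuming SeqCount exposes ReadOk alongside the BeginRead/BeginWrite/EndWrite calls exercised in the tests below (illustrative only):

// loadInt returns a copy of *ptr that did not race with a writer critical
// section in sc. Plain assignment is fine here only because int contains no
// pointers; the template above uses Memmove under the race detector instead.
func loadInt(sc *sync.SeqCount, ptr *int) int {
	for {
		epoch := sc.BeginRead()
		val := *ptr
		if sc.ReadOk(epoch) {
			return val
		}
	}
}

A writer wraps its update in sc.BeginWrite() / sc.EndWrite(), as seqatomic_test.go below demonstrates.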
diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
index ba18f3238..eba21518d 100644
--- a/pkg/syncutil/seqatomictest/BUILD
+++ b/pkg/sync/seqatomictest/BUILD
@@ -9,7 +9,7 @@ go_template_instance(
out = "seqatomic_int_unsafe.go",
package = "seqatomic",
suffix = "Int",
- template = "//pkg/syncutil:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "int",
},
@@ -18,9 +18,9 @@ go_template_instance(
go_library(
name = "seqatomic",
srcs = ["seqatomic_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic",
+ importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic",
deps = [
- "//pkg/syncutil",
+ "//pkg/sync",
],
)
@@ -29,7 +29,5 @@ go_test(
size = "small",
srcs = ["seqatomic_test.go"],
embed = [":seqatomic"],
- deps = [
- "//pkg/syncutil",
- ],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go
index b0db44999..2c4568b07 100644
--- a/pkg/syncutil/seqatomictest/seqatomic_test.go
+++ b/pkg/sync/seqatomictest/seqatomic_test.go
@@ -19,11 +19,11 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/syncutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSeqAtomicLoadUncontended(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
if got := SeqAtomicLoadInt(&seq, &data); got != want {
@@ -32,7 +32,7 @@ func TestSeqAtomicLoadUncontended(t *testing.T) {
}
func TestSeqAtomicLoadAfterWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -44,7 +44,7 @@ func TestSeqAtomicLoadAfterWrite(t *testing.T) {
}
func TestSeqAtomicLoadDuringWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -59,7 +59,7 @@ func TestSeqAtomicLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadUncontended(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
epoch := seq.BeginRead()
@@ -69,7 +69,7 @@ func TestSeqAtomicTryLoadUncontended(t *testing.T) {
}
func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -80,7 +80,7 @@ func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -91,7 +91,7 @@ func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
}
func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
@@ -104,7 +104,7 @@ func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
}
func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) {
- var seq syncutil.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
diff --git a/pkg/syncutil/seqcount.go b/pkg/sync/seqcount.go
index 11d8dbfaa..a1e895352 100644
--- a/pkg/syncutil/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package syncutil
+package sync
import (
"fmt"
diff --git a/pkg/syncutil/seqcount_test.go b/pkg/sync/seqcount_test.go
index 14d6aedea..6eb7b4b59 100644
--- a/pkg/syncutil/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package syncutil
+package sync
import (
"reflect"
diff --git a/pkg/syncutil/syncutil.go b/pkg/sync/syncutil.go
index 66e750d06..b16cf5333 100644
--- a/pkg/syncutil/syncutil.go
+++ b/pkg/sync/syncutil.go
@@ -3,5 +3,5 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package syncutil provides synchronization primitives.
-package syncutil
+// Package sync provides synchronization primitives.
+package sync
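
For orientation, a minimal sketch of how the renamed sync.SeqCount is used by readers and writers, assuming the BeginRead/BeginWrite/EndWrite methods exercised in the tests above plus a ReadOk method to validate the read epoch (ReadOk is inferred from the SeqAtomicLoad retry loop and is an assumption here, not part of this diff):

package example

import "gvisor.dev/gvisor/pkg/sync"

type counter struct {
	seq sync.SeqCount
	val int
}

// set publishes a new value inside a writer critical section.
func (c *counter) set(v int) {
	c.seq.BeginWrite()
	c.val = v
	c.seq.EndWrite()
}

// get retries until it observes a value that did not race with a writer,
// mirroring the fast (non-race-detector) path of SeqAtomicLoad above.
func (c *counter) get() int {
	for {
		epoch := c.seq.BeginRead()
		v := c.val
		if c.seq.ReadOk(epoch) {
			return v
		}
	}
}
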
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index e07ebd153..ebc8d0209 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -15,6 +15,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip/buffer",
"//pkg/tcpip/iptables",
"//pkg/waiter",
@@ -29,7 +30,7 @@ go_test(
)
go_test(
- name = "timer_test",
+ name = "tcpip_x_test",
size = "small",
srcs = ["timer_test.go"],
deps = [":tcpip"],
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 78df5a0b1..3df7d18d3 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -9,6 +9,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index cd6ce930a..a2f44b496 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -20,9 +20,9 @@ import (
"errors"
"io"
"net"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f15bf1f1..885d773b0 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
@@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ }
+ if got := cm.TOS; got != want {
+ t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -754,3 +770,9 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
}
}
}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+func NDPRS() NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize)
+}
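
A hedged usage sketch for the new ControlMessagesChecker added above: once a test has ancillary data in hand, the checker returned by ReceiveTOS can be applied directly. How the tcpip.ControlMessages value is obtained (e.g. from an endpoint read) is assumed and not shown:

package example_test

import (
	"testing"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/checker"
)

// verifyTOS applies the ControlMessagesChecker to ancillary data that the
// surrounding test is assumed to have received from an endpoint.
func verifyTOS(t *testing.T, cm tcpip.ControlMessages) {
	t.Helper()
	checker.ReceiveTOS(0x80)(t, cm)
}
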
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index f2061c778..cd747d100 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -20,6 +20,7 @@ go_library(
"ndp_neighbor_solicit.go",
"ndp_options.go",
"ndp_router_advert.go",
+ "ndp_router_solicit.go",
"tcp.go",
"udp.go",
],
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 135a60b12..70e6ce095 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -84,6 +84,13 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ // IPv6AllRoutersMulticastAddress is a link-local multicast group that
+ // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers on a link.
+ //
+ // The address is ff02::2.
+ IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -333,6 +340,17 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6UniqueLocalAddress determines if the provided address is an IPv6
+// unique-local address (within the prefix FC00::/7).
+func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ // According to RFC 4193 section 3.1, a unique local address has the prefix
+ // FC00::/7.
+ return (addr[0] & 0xfe) == 0xfc
+}
+
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
@@ -371,3 +389,35 @@ func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []by
return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
}
+
+// IPv6AddressScope is the scope of an IPv6 address.
+type IPv6AddressScope int
+
+const (
+ // LinkLocalScope indicates a link-local address.
+ LinkLocalScope IPv6AddressScope = iota
+
+ // UniqueLocalScope indicates a unique-local address.
+ UniqueLocalScope
+
+ // GlobalScope indicates a global address.
+ GlobalScope
+)
+
+// ScopeForIPv6Address returns the scope for an IPv6 address.
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+ if len(addr) != IPv6AddressSize {
+ return GlobalScope, tcpip.ErrBadAddress
+ }
+
+ switch {
+ case IsV6LinkLocalAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6UniqueLocalAddress(addr):
+ return UniqueLocalScope, nil
+
+ default:
+ return GlobalScope, nil
+ }
+}
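
A short sketch of how the new scope helper might be used, for example as a building block in source-address selection; the sameScope helper itself is illustrative and not part of this change:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// sameScope reports whether two IPv6 addresses share a scope, returning the
// first classification error encountered.
func sameScope(a, b tcpip.Address) (bool, *tcpip.Error) {
	sa, err := header.ScopeForIPv6Address(a)
	if err != nil {
		return false, err
	}
	sb, err := header.ScopeForIPv6Address(b)
	if err != nil {
		return false, err
	}
	return sa == sb, nil
}
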
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index 1994003ed..29f54bc57 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -25,7 +25,13 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
-const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
@@ -206,3 +212,91 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
})
}
}
+
+func TestIsV6UniqueLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Unique 1",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Valid Unique 2",
+ addr: uniqueLocalAddr2,
+ expected: true,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestScopeForIPv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ scope header.IPv6AddressScope
+ err *tcpip.Error
+ }{
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ scope: header.UniqueLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ scope: header.GlobalScope,
+ err: nil,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ scope: header.GlobalScope,
+ err: tcpip.ErrBadAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := header.ScopeForIPv6Address(test.addr)
+ if err != test.err {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ }
+ if got != test.scope {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
new file mode 100644
index 000000000..9e67ba95d
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_solicit.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.1 for more details.
+type NDPRouterSolicit []byte
+
+const (
+ // NDPRSMinimumSize is the minimum size of a valid NDP Router
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPRSMinimumSize = 4
+
+ // ndpRSOptionsOffset is the start of the NDP options in an
+ // NDPRouterSolicit.
+ ndpRSOptionsOffset = 4
+)
+
+// Options returns an NDPOptions of the options body.
+func (b NDPRouterSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpRSOptionsOffset:])
+}
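
As a usage note, the new type is a thin view over the ICMPv6 payload; a caller that has already located the Router Solicitation body can wrap it and read its options. A minimal sketch (how the body is extracted from the packet is assumed):

package example

import "gvisor.dev/gvisor/pkg/tcpip/header"

// rsOptions wraps the body of an ICMPv6 Router Solicitation and returns its
// NDP options, rejecting bodies shorter than the minimum RS size.
func rsOptions(body []byte) (header.NDPOptions, bool) {
	if len(body) < header.NDPRSMinimumSize {
		return nil, false
	}
	return header.NDPRouterSolicit(body).Options(), true
}
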
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
index cc5f531e2..64769c333 100644
--- a/pkg/tcpip/iptables/BUILD
+++ b/pkg/tcpip/iptables/BUILD
@@ -11,5 +11,8 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip/buffer"],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip/buffer",
+ ],
)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index 68c68d4aa..647970133 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -16,66 +16,114 @@
// tool.
package iptables
+// Table names.
const (
- tablenameNat = "nat"
- tablenameMangle = "mangle"
+ TablenameNat = "nat"
+ TablenameMangle = "mangle"
+ TablenameFilter = "filter"
)
// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
const (
- chainNamePrerouting = "PREROUTING"
- chainNameInput = "INPUT"
- chainNameForward = "FORWARD"
- chainNameOutput = "OUTPUT"
- chainNamePostrouting = "POSTROUTING"
+ ChainNamePrerouting = "PREROUTING"
+ ChainNameInput = "INPUT"
+ ChainNameForward = "FORWARD"
+ ChainNameOutput = "OUTPUT"
+ ChainNamePostrouting = "POSTROUTING"
)
+// HookUnset indicates that there is no hook set for an entrypoint or
+// underflow.
+const HookUnset = -1
+
// DefaultTables returns a default set of tables. Each chain is set to accept
// all packets.
func DefaultTables() IPTables {
+ // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
+ // iotas.
return IPTables{
Tables: map[string]Table{
- tablenameNat: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Input: unconditionalAcceptChain(chainNameInput),
- Output: unconditionalAcceptChain(chainNameOutput),
- Postrouting: unconditionalAcceptChain(chainNamePostrouting),
+ TablenameNat: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
},
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Input: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
- Postrouting: UnconditionalAcceptTarget{},
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
},
- UserChains: map[string]Chain{},
+ UserChains: map[string]int{},
},
- tablenameMangle: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Output: unconditionalAcceptChain(chainNameOutput),
+ TablenameMangle: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
},
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
},
- UserChains: map[string]Chain{},
+ UserChains: map[string]int{},
+ },
+ TablenameFilter: Table{
+ Rules: []Rule{
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: UnconditionalAcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ Underflows: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ UserChains: map[string]int{},
},
},
Priorities: map[Hook][]string{
- Prerouting: []string{tablenameMangle, tablenameNat},
- Output: []string{tablenameMangle, tablenameNat},
+ Input: []string{TablenameNat, TablenameFilter},
+ Prerouting: []string{TablenameMangle, TablenameNat},
+ Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
},
}
}
-func unconditionalAcceptChain(name string) Chain {
- return Chain{
- Name: name,
- Rules: []Rule{
- Rule{
- Target: UnconditionalAcceptTarget{},
- },
+// EmptyFilterTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyFilterTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
},
+ UserChains: map[string]int{},
}
}
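
The tables above populate the new index-based layout: BuiltinChains and Underflows (see the Table changes in types.go further below) map each hook to an index into Rules rather than to a separate Chain value. A minimal, illustrative lookup under that layout, ignoring matchers and user-defined chains, might look like:

package example

import (
	"gvisor.dev/gvisor/pkg/tcpip/buffer"
	"gvisor.dev/gvisor/pkg/tcpip/iptables"
)

// runEntrypoint runs the target of the entrypoint rule for hook, falling back
// to the hook's underflow rule when no entrypoint is set (it assumes the table
// defines an underflow for hook, as DefaultTables does). Real traversal also
// evaluates matchers and may walk multiple rules; that logic is omitted here.
func runEntrypoint(table iptables.Table, hook iptables.Hook, pkt buffer.VectorisedView) iptables.Verdict {
	idx, ok := table.BuiltinChains[hook]
	if !ok || idx == iptables.HookUnset {
		idx = table.Underflows[hook]
	}
	verdict, _ := table.Rules[idx].Target.Action(pkt)
	return verdict
}
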
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index 19a7f77e3..b94a4c941 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -16,7 +16,10 @@
package iptables
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
// UnconditionalAcceptTarget accepts all packets.
type UnconditionalAcceptTarget struct{}
@@ -33,3 +36,14 @@ type UnconditionalDropTarget struct{}
func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
return Drop, ""
}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+ log.Warningf("ErrorTarget triggered.")
+ return Drop, ""
+
+}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 42a79ef9f..540f8c0b4 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -61,9 +61,12 @@ const (
type Verdict int
const (
+ // Invalid indicates an unknown or erroneous verdict.
+ Invalid Verdict = iota
+
// Accept indicates the packet should continue traversing netstack as
// normal.
- Accept Verdict = iota
+ Accept
// Drop inicates the packet should be dropped, stopping traversing
// netstack.
@@ -104,29 +107,22 @@ type IPTables struct {
Priorities map[Hook][]string
}
-// A Table defines a set of chains and hooks into the network stack. The
-// currently supported tables are:
-// * nat
-// * mangle
+// A Table defines a set of chains and hooks into the network stack. It is
+// really just a list of rules with some metadata for entrypoints and such.
type Table struct {
- // BuiltinChains holds the un-deletable chains built into netstack. If
- // a hook isn't present in the map, this table doesn't utilize that
- // hook.
- BuiltinChains map[Hook]Chain
+ // Rules holds the rules that make up the table.
+ Rules []Rule
- // DefaultTargets holds a target for each hook that will be executed if
- // chain traversal doesn't yield a verdict.
- DefaultTargets map[Hook]Target
+ // BuiltinChains maps builtin chains to their entrypoint rule in Rules.
+ BuiltinChains map[Hook]int
+
+ // Underflows maps builtin chains to their underflow rule in Rules
+ // (i.e. the rule to execute if the chain returns without a verdict).
+ Underflows map[Hook]int
// UserChains holds user-defined chains for the keyed by name. Users
// can give their chains arbitrary names.
- UserChains map[string]Chain
-
- // Chains maps names to chains for both builtin and user-defined chains.
- // Its entries point to Chains already either in BuiltinChains or
- // UserChains, and its purpose is to make looking up tables by name
- // fast.
- Chains map[string]*Chain
+ UserChains map[string]int
// Metadata holds information about the Table that is useful to users
// of IPTables, but not to the netstack IPTables code itself.
@@ -152,21 +148,6 @@ func (table *Table) SetMetadata(metadata interface{}) {
table.metadata = metadata
}
-// A Chain defines a list of rules for packet processing. When a packet
-// traverses a chain, it is checked against each rule until either a rule
-// returns a verdict or the chain ends.
-//
-// By convention, builtin chains end with a rule that matches everything and
-// returns either Accept or Drop. User-defined chains end with Return. These
-// aren't strictly necessary here, but the iptables tool writes tables this way.
-type Chain struct {
- // Name is the chain name.
- Name string
-
- // Rules is the list of rules to traverse.
- Rules []Rule
-}
-
// A Rule is a packet processing rule. It consists of two pieces. First it
// contains zero or more matchers, each of which is a specification of which
// packets this rule applies to. If there are no matchers in the rule, it
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 897c94821..66cc53ed4 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -16,6 +16,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index fa8a703d9..b7f60178e 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,10 +41,10 @@ package fdbased
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index a4f9cdd69..09165dd4c 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -15,6 +15,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -31,6 +32,7 @@ go_test(
],
embed = [":sharedmem"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 6b5bc542c..a0d4ad0be 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -21,4 +21,5 @@ go_test(
"pipe_test.go",
],
embed = [":pipe"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index 59ef69a8b..dc239a0d0 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -18,8 +18,9 @@ import (
"math/rand"
"reflect"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSimpleReadWrite(t *testing.T) {
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 080f9d667..655e537c4 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -23,11 +23,11 @@
package sharedmem
import (
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 89603c48f..5c729a439 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -22,11 +22,11 @@ import (
"math/rand"
"os"
"strings"
- "sync"
"syscall"
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index acf1e022c..ed16076fd 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
],
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6da5238ec..92f2aa13a 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -19,9 +19,9 @@ package fragmentation
import (
"fmt"
"log"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9e002e396..0a83d81f2 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index e156b01f6..a6ef3bdcc 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -9,6 +9,7 @@ go_library(
importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 6c5e19e8f..b937cb84b 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -18,9 +18,9 @@ package ports
import (
"math"
"math/rand"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index b8f9517d0..783351a69 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -36,6 +36,7 @@ go_library(
"//pkg/ilist",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
@@ -50,7 +51,7 @@ go_library(
go_test(
name = "stack_x_test",
- size = "small",
+ size = "medium",
srcs = [
"ndp_test.go",
"stack_test.go",
@@ -83,6 +84,7 @@ go_test(
embed = [":stack"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
],
)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 267df60d1..403557fd7 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -16,10 +16,10 @@ package stack
import (
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 9946b8fe8..1baa498d0 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,12 +16,12 @@ package stack
import (
"fmt"
- "sync"
"sync/atomic"
"testing"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 35825ebf7..c99d387d5 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"log"
+ "math/rand"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -38,24 +39,36 @@ const (
// Default = 1s (from RFC 4861 section 10).
defaultRetransmitTimer = time.Second
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when a NIC becomes enabled.
+ //
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
+
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
+ //
+ // Default = 4s (from RFC 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from RFC 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
// defaultHandleRAs is the default configuration for whether or not to
// handle incoming Router Advertisements as a host.
- //
- // Default = true.
defaultHandleRAs = true
// defaultDiscoverDefaultRouters is the default configuration for
// whether or not to discover default routers from incoming Router
// Advertisements, as a host.
- //
- // Default = true.
defaultDiscoverDefaultRouters = true
// defaultDiscoverOnLinkPrefixes is the default configuration for
// whether or not to discover on-link prefixes from incoming Router
// Advertisements' Prefix Information option, as a host.
- //
- // Default = true.
defaultDiscoverOnLinkPrefixes = true
// defaultAutoGenGlobalAddresses is the default configuration for
@@ -74,26 +87,31 @@ const (
// value of 0 means unspecified, so the smallest valid value is 1.
// Note, the unit of the RetransmitTimer field in the Router
// Advertisement is milliseconds.
- //
- // Min = 1ms.
minimumRetransmitTimer = time.Millisecond
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
// MaxDiscoveredDefaultRouters is the maximum number of discovered
// default routers. The stack should stop discovering new routers after
// discovering MaxDiscoveredDefaultRouters routers.
//
// This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
// SHOULD be more.
- //
- // Max = 10.
MaxDiscoveredDefaultRouters = 10
// MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
// on-link prefixes. The stack should stop discovering new on-link
// prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
// prefixes.
- //
- // Max = 10.
MaxDiscoveredOnLinkPrefixes = 10
// validPrefixLenForAutoGen is the expected prefix length that an
@@ -115,6 +133,30 @@ var (
MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
)
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // will return all available configuration information.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
+)
+
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
@@ -194,7 +236,20 @@ type NDPDispatcher interface {
// already known DNS servers. If called with known DNS servers, their
// valid lifetimes must be refreshed to lifetime (it may be increased,
// decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDHCPv6Configuration will be called with an updated configuration that is
+ // available via DHCPv6 for a specified NIC.
+ //
+ // NDPDispatcher assumes that the initial configuration available via DHCPv6
+ // is DHCPv6NoConfiguration.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
// NDPConfigurations is the NDP configurations for the netstack.
@@ -208,9 +263,24 @@ type NDPConfigurations struct {
// The amount of time to wait between sending Neighbor solicitation
// messages.
//
- // Must be greater than 0.5s.
+ // Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
+ // The number of Router Solicitation messages to send when the NIC
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
// HandleRAs determines whether or not Router Advertisements will be
// processed.
HandleRAs bool
@@ -241,12 +311,15 @@ type NDPConfigurations struct {
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
- HandleRAs: defaultHandleRAs,
- DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
- DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
- AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
}
}
@@ -255,10 +328,24 @@ func DefaultNDPConfigurations() NDPConfigurations {
//
// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
// defaultRetransmitTimer will be used.
+//
+// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then
+// a value of defaultRtrSolicitationInterval will be used.
+//
+// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then
+// a value of defaultMaxRtrSolicitationDelay will be used.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
}
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
}
// ndpState is the per-interface NDP state.
@@ -279,8 +366,15 @@ type ndpState struct {
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+ // The timer used to send the next router solicitation message.
+ // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
+ rtrSolicitTimer *time.Timer
+
// The addresses generated by SLAAC.
autoGenAddresses map[tcpip.Address]autoGenAddressState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
}
// dadState holds the Duplicate Address Detection timer and channel to signal
@@ -461,10 +555,12 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u
// address.
panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
}
+ snmcRef.incRef()
// Use the unspecified address as the source address when performing
// DAD.
r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+ defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
@@ -533,6 +629,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
return
}
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ }
+ }
+
// Is the NIC configured to discover default routers?
if ndp.configs.DiscoverDefaultRouters {
rtr, ok := ndp.defaultRouters[ip]
@@ -876,7 +994,7 @@ func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration
// If the preferred lifetime is zero, then the address should be considered
// deprecated.
deprecated := pl == 0
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ ref, err := ndp.nic.addPermanentAddressLocked(protocolAddr, FirstPrimaryEndpoint, slaac, deprecated)
if err != nil {
log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
}
@@ -1070,3 +1188,84 @@ func (ndp *ndpState) cleanupHostOnlyState() {
log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got)
}
}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicitTimer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ ndp.rtrSolicitTimer = time.AfterFunc(delay, func() {
+ // Send an RS message with the unspecified source address.
+ ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true)
+ r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ pkt := header.ICMPv6(hdr.Prepend(payloadSize))
+ pkt.SetType(header.ICMPv6RouterSolicit)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, tcpip.PacketBuffer{Header: hdr},
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
+ if remaining == 0 {
+ ndp.rtrSolicitTimer = nil
+ } else if ndp.rtrSolicitTimer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the NIC
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicitTimer == nil {
+ // Nothing to do.
+ return
+ }
+
+ ndp.rtrSolicitTimer.Stop()
+ ndp.rtrSolicitTimer = nil
+}
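
For reference, the DHCPv6 portion of handleRA above reduces to a pure mapping from the RA's Managed (M) and Other Configurations (O) flags to a DHCPv6ConfigurationFromNDPRA value. A standalone sketch of that mapping; the boolean parameters stand in for ra.ManagedAddrConfFlag() and ra.OtherConfFlag():

package example

import "gvisor.dev/gvisor/pkg/tcpip/stack"

// dhcpv6ConfigFromRAFlags mirrors the switch in handleRA: the M flag implies
// addresses (and all other information) are available via DHCPv6, while the O
// flag alone implies only other configuration information is.
func dhcpv6ConfigFromRAFlags(managed, other bool) stack.DHCPv6ConfigurationFromNDPRA {
	switch {
	case managed:
		return stack.DHCPv6ManagedAddress
	case other:
		return stack.DHCPv6OtherConfigurations
	default:
		return stack.DHCPv6NoConfiguration
	}
}
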
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index d334af289..1a52e0e68 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -35,12 +35,12 @@ import (
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- linkAddr1 = "\x02\x02\x03\x04\x05\x06"
- linkAddr2 = "\x02\x02\x03\x04\x05\x07"
- linkAddr3 = "\x02\x02\x03\x04\x05\x08"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
defaultTimeout = 100 * time.Millisecond
)
@@ -162,18 +162,24 @@ type ndpRDNSSEvent struct {
rdnss ndpRDNSS
}
+type ndpDHCPv6Event struct {
+ nicID tcpip.NICID
+ configuration stack.DHCPv6ConfigurationFromNDPRA
+}
+
var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
type ndpDispatcher struct {
- dadC chan ndpDADEvent
- routerC chan ndpRouterEvent
- rememberRouter bool
- prefixC chan ndpPrefixEvent
- rememberPrefix bool
- autoGenAddrC chan ndpAutoGenAddrEvent
- rdnssC chan ndpRDNSSEvent
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+ dhcpv6ConfigurationC chan ndpDHCPv6Event
}
// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
@@ -280,6 +286,16 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc
}
}
+// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ if c := n.dhcpv6ConfigurationC; c != nil {
+ c <- ndpDHCPv6Event{
+ nicID,
+ configuration,
+ }
+ }
+}
+
// TestDADResolve tests that an address successfully resolves after performing
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
@@ -797,21 +813,32 @@ func TestSetNDPConfigurations(t *testing.T) {
}
}
-// raBufWithOpts returns a valid NDP Router Advertisement with options.
-//
-// Note, raBufWithOpts does not populate any of the RA fields other than the
-// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
+// and DHCPv6 configurations specified.
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
pkt.SetType(header.ICMPv6RouterAdvert)
pkt.SetCode(0)
- ra := header.NDPRouterAdvert(pkt.NDPPayload())
+ raPayload := pkt.NDPPayload()
+ ra := header.NDPRouterAdvert(raPayload)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(raPayload[2:], rl)
+ // Populate the Managed Address flag field.
+ if managedAddress {
+ // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
+ // of the RA payload.
+ raPayload[1] |= (1 << 7)
+ }
+ // Populate the Other Configurations flag field.
+ if otherConfigurations {
+ // The Other Configurations flag field is the 6th bit of byte #1
+ // (0-indexing) of the RA payload.
+ raPayload[1] |= (1 << 6)
+ }
opts := ra.Options()
opts.Serialize(optSer)
- // Populate the Router Lifetime.
- binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -826,6 +853,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+}
+
+// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
+// fields set.
+//
+// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
+// DHCPv6 related ones.
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) tcpip.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfigurations, header.NDPOptionsSerializer{})
+}
+
// raBuf returns a valid NDP Router Advertisement.
//
// Note, raBuf does not populate any of the RA fields other than the
@@ -1688,9 +1732,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
return ndpDisp, e, s
}
-// addrForNewConnection returns the local address used when creating a new
-// connection.
-func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+// addrForNewConnectionTo returns the local address used when creating a new
+// connection to addr.
+func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -1704,8 +1750,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
}
- if err := ep.Connect(dstAddr); err != nil {
- t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ if err := ep.Connect(addr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", addr, err)
}
got, err := ep.GetLocalAddress()
if err != nil {
@@ -1714,9 +1760,19 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
return got.Addr
}
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+ t.Helper()
+
+ return addrForNewConnectionTo(t, s, dstAddr)
+}
+
// addrForNewConnectionWithAddr returns the local address used when creating a
// new connection with a specific local address.
func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
@@ -2389,6 +2445,119 @@ func TestAutoGenAddrRemoval(t *testing.T) {
}
}
+// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
+// assigned to the NIC but is in the permanentExpired state.
+func TestAutoGenAddrAfterRemoval(t *testing.T) {
+ t.Parallel()
+
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
+
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
+
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+}
+
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
// is already assigned to the NIC, the static address remains.
func TestAutoGenAddrStaticConflict(t *testing.T) {
@@ -2951,3 +3120,318 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
default:
}
}
+
+// TestDHCPv6ConfigurationFromNDPRA tests that the NDPDispatcher is properly
+// informed when new information about what configurations are available via
+// DHCPv6 is learned.
+func TestDHCPv6ConfigurationFromNDPRA(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ t.Helper()
+ select {
+ case e := <-ndpDisp.dhcpv6ConfigurationC:
+ if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
+ t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DHCPv6 configuration event")
+ }
+ }
+
+ expectNoDHCPv6Event := func() {
+ t.Helper()
+ select {
+ case <-ndpDisp.dhcpv6ConfigurationC:
+ t.Fatal("unexpected DHCPv6 configuration event")
+ default:
+ }
+ }
+
+ // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ // Receiving the same update again should not result in an event to the
+ // NDPDispatcher.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to none.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ //
+ // Note, when the M flag is set, the O flag is redundant.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+ // Even though the DHCPv6 flags are different, the effective configuration is
+ // the same so we should not receive a new event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+}
+
+// TestRouterSolicitation tests the initial Router Solicitations that are sent
+// when a NIC newly becomes enabled.
+func TestRouterSolicitation(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxRtrSolicit uint8
+ rtrSolicitInt time.Duration
+ effectiveRtrSolicitInt time.Duration
+ maxRtrSolicitDelay time.Duration
+ effectiveMaxRtrSolicitDelay time.Duration
+ }{
+ {
+ name: "Single RS with delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ {
+ name: "Two RS with delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 500 * time.Millisecond,
+ effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
+ },
+ {
+ name: "Single RS without delay",
+ maxRtrSolicit: 1,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS without delay and invalid zero interval",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 0,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Three RS without delay",
+ maxRtrSolicit: 3,
+ rtrSolicitInt: 500 * time.Millisecond,
+ effectiveRtrSolicitInt: 500 * time.Millisecond,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS with invalid negative delay",
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: -3 * time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+ e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case p := <-e.C:
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t,
+ p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(),
+ )
+
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for packet")
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet")
+ case <-time.After(timeout):
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Make sure each RS got sent at the right
+ // times.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ remaining--
+ }
+ for ; remaining > 0; remaining-- {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout)
+ waitForPkt(2 * defaultTimeout)
+ }
+
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout)
+ }
+
+ // Make sure the counter got properly
+ // incremented.
+ if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+}
+
+// TestStopStartSolicitingRouters tests that when forwarding is enabled or
+// disabled, router solicitations are stopped or started, respectively.
+func TestStopStartSolicitingRouters(t *testing.T) {
+ t.Parallel()
+
+ const interval = 500 * time.Millisecond
+ const delay = time.Second
+ const maxRtrSolicitations = 3
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ select {
+ case p := <-e.C:
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for packet")
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Enable forwarding which should stop router solicitations.
+ s.SetForwarding(true)
+ select {
+ case <-e.C:
+ // A single RS may have been sent before forwarding was enabled.
+ select {
+ case <-e.C:
+ t.Fatal("Should not have sent more than one RS message")
+ case <-time.After(interval + defaultTimeout):
+ }
+ case <-time.After(delay + defaultTimeout):
+ }
+
+ // Enabling forwarding again should do nothing.
+ s.SetForwarding(true)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ case <-time.After(delay + defaultTimeout):
+ }
+
+ // Disable forwarding which should start router solicitations.
+ s.SetForwarding(false)
+ waitForPkt(delay + defaultTimeout)
+ waitForPkt(interval + defaultTimeout)
+ waitForPkt(interval + defaultTimeout)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ case <-time.After(interval + defaultTimeout):
+ }
+
+ // Disabling forwarding again should do nothing.
+ s.SetForwarding(false)
+ select {
+ case <-e.C:
+ t.Fatal("unexpectedly got a packet after becoming a router")
+ case <-time.After(delay + defaultTimeout):
+ }
+}
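
Note: the solicitation timing exercised above comes straight from the NDP configuration. A minimal standalone sketch of wiring it up, using only the stack, ipv6, channel and tcpip packages already imported by these tests, might look like this:

package main

import (
	"time"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

func main() {
	// Channel endpoint large enough to observe every solicitation.
	e := channel.New(3, 1280, tcpip.LinkAddress("\x02\x00\x00\x00\x00\x01"))

	// A host sends up to MaxRtrSolicitations RSs, spaced
	// RtrSolicitationInterval apart, after an initial delay of at most
	// MaxRtrSolicitationDelay.
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
		NDPConfigs: stack.NDPConfigurations{
			MaxRtrSolicitations:     3,
			RtrSolicitationInterval: 500 * time.Millisecond,
			MaxRtrSolicitationDelay: time.Second,
		},
	})
	if err := s.CreateNIC(1, e); err != nil {
		panic(err)
	}

	// Enabling forwarding makes the stack a router and stops solicitations;
	// disabling it starts them again, as TestStopStartSolicitingRouters checks.
	s.SetForwarding(true)
	s.SetForwarding(false)
}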
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4144d5d0f..4452a1302 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,10 +15,12 @@
package stack
import (
+ "log"
+ "sort"
"strings"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -27,10 +29,11 @@ import (
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
mu sync.RWMutex
spoofing bool
@@ -84,7 +87,7 @@ const (
)
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
@@ -98,6 +101,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
id: id,
name: name,
linkEP: ep,
+ context: ctx,
primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -173,49 +177,72 @@ func (n *NIC) enable() *tcpip.Error {
}
// Do not auto-generate an IPv6 link-local address for loopback devices.
- if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() {
- return nil
- }
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ var addr tcpip.Address
+ if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
+ } else {
+ l2addr := n.linkEP.LinkAddress()
- var addr tcpip.Address
- if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey)
- } else {
- l2addr := n.linkEP.LinkAddress()
+ // Only attempt to generate the link-local address if we have a valid MAC
+ // address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ if !header.IsValidUnicastEthernetAddress(l2addr) {
+ return nil
+ }
- // Only attempt to generate the link-local address if we have a valid MAC
- // address.
- //
- // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
- // LinkEndpoint.LinkAddress) before reaching this point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
+ addr = header.LinkLocalAddr(l2addr)
}
- addr = header.LinkLocalAddr(l2addr)
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ },
+ }, CanBePrimaryEndpoint, static, false /* deprecated */); err != nil {
+ return err
+ }
}
- _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- },
- }, CanBePrimaryEndpoint)
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyway.
+ //
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !n.stack.forwarding {
+ n.ndp.startSolicitingRouters()
+ }
- return err
+ return nil
}
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
// routers, discovered on-link prefixes, and auto-generated addresses) will
-// be cleaned up/invalidated.
+// be cleaned up/invalidated and NDP router solicitations will be stopped.
func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
n.ndp.cleanupHostOnlyState()
+ n.ndp.stopSolicitingRouters()
+}
+
+// becomeIPv6Host transitions n into an IPv6 host.
+//
+// When transitioning into an IPv6 host, NDP router solicitations will be
+// started.
+func (n *NIC) becomeIPv6Host() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.ndp.startSolicitingRouters()
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
@@ -249,13 +276,17 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint returns the primary endpoint of n for the given network
-// protocol.
-//
// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists. If no non-deprecated endpoint exists, the first deprecated
-// endpoint will be returned.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
+// endpoint exists, the first deprecated endpoint will be returned.
+//
+// If an IPv6 primary endpoint is requested, Source Address Selection (as
+// defined by RFC 6724 section 5) will be performed.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ return n.primaryIPv6Endpoint(remoteAddr)
+ }
+
n.mu.RLock()
defer n.mu.RUnlock()
@@ -294,6 +325,103 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return deprecatedEndpoint
}
+// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
+// 6724 section 5).
+type ipv6AddrCandidate struct {
+ ref *referencedNetworkEndpoint
+ scope header.IPv6AddressScope
+}
+
+// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ primaryAddrs := n.primary[header.IPv6ProtocolNumber]
+
+ if len(primaryAddrs) == 0 {
+ return nil
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
+ for _, r := range primaryAddrs {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !r.isValidForOutgoing() {
+ continue
+ }
+
+ addr := r.ep.ID().LocalAddress
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err)
+ }
+
+ cs = append(cs, ipv6AddrCandidate{
+ ref: r,
+ scope: scope,
+ })
+ }
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.ref.ep.ID().LocalAddress == remoteAddr {
+ return true
+ }
+ if sb.ref.ep.ID().LocalAddress == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if r := c.ref; r.tryIncRef() {
+ return r
+ }
+ }
+
+ return nil
+}
+
// hasPermanentAddrLocked returns true if n has a permanent (including currently
// tentative) address, addr.
func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
@@ -405,7 +533,12 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+// addPermanentAddressLocked adds a permanent address to n.
+//
+// If n already has the address in a non-permanent state,
+// addPermanentAddressLocked will promote it to permanent and update the
+// endpoint with the properties provided.
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if ref, ok := n.endpoints[id]; ok {
switch ref.getKind() {
@@ -413,10 +546,14 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect
- // the new peb.
+ // Promote the endpoint to become permanent and respect the new peb,
+ // configType and deprecated status.
if ref.tryIncRef() {
+ // TODO(b/147748385): Perform Duplicate Address Detection when promoting
+ // an IPv6 endpoint to permanent.
ref.setKind(permanent)
+ ref.deprecated = deprecated
+ ref.configType = configType
refs := n.primary[ref.protocol]
for i, r := range refs {
@@ -448,9 +585,13 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent, static, false)
+ return n.addAddressLocked(protocolAddress, peb, permanent, configType, deprecated)
}
+// addAddressLocked adds a new protocolAddress to n.
+//
+// If the address is already known by n (irrespective of the state it is in),
+// addAddressLocked does nothing and returns tcpip.ErrDuplicateAddress.
func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
// TODO(b/141022673): Validate IP address before adding them.
@@ -525,7 +666,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addPermanentAddressLocked(protocolAddress, peb)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb, static, false /* deprecated */)
n.mu.Unlock()
return err
@@ -658,7 +799,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Unlock()
}
-// Subnets returns the Subnets associated with this NIC.
+// AddressRanges returns the Subnets associated with this NIC.
func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
@@ -807,7 +948,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, static, false /* deprecated */); err != nil {
return err
}
}
@@ -1185,7 +1326,8 @@ type referencedNetworkEndpoint struct {
kind networkEndpointKind
// configType is the method that was used to configure this endpoint.
- // This must never change after the endpoint is added to a NIC.
+ // This must never change except during endpoint creation and promotion to
+ // permanent.
configType networkEndpointConfigType
// deprecated indicates whether or not the endpoint should be considered
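
For readers unfamiliar with RFC 6724, the ordering applied by primaryIPv6Endpoint (rules 1-3 only) can be illustrated in isolation with the simplified, self-contained sketch below; the candidate type and scope constants are stand-ins for illustration, not the stack's real types:

package main

import (
	"fmt"
	"sort"
)

// scope is a simplified stand-in for header.IPv6AddressScope.
type scope int

const (
	linkLocalScope scope = iota
	globalScope
)

type candidate struct {
	addr       string
	scope      scope
	deprecated bool
}

// orderSources sorts candidates per RFC 6724 section 5, rules 1-3 only:
// prefer the same address, then an appropriate scope, then a non-deprecated
// address; otherwise keep the original order.
func orderSources(cs []candidate, remote string, remoteScope scope) {
	sort.SliceStable(cs, func(i, j int) bool {
		sa, sb := cs[i], cs[j]

		// Rule 1: prefer the same address.
		if sa.addr == remote {
			return true
		}
		if sb.addr == remote {
			return false
		}

		// Rule 2: prefer an appropriate scope.
		if sa.scope < sb.scope {
			return sa.scope >= remoteScope
		} else if sb.scope < sa.scope {
			return sb.scope < remoteScope
		}

		// Rule 3: avoid deprecated addresses.
		if sa.deprecated != sb.deprecated {
			return sb.deprecated
		}

		return false // Equal; SliceStable keeps the original order.
	})
}

func main() {
	cs := []candidate{
		{addr: "fe80::1", scope: linkLocalScope},
		{addr: "2001:db8::1", scope: globalScope, deprecated: true},
		{addr: "2001:db8::2", scope: globalScope},
	}
	orderSources(cs, "2001:db8::99", globalScope)
	fmt.Println(cs[0].addr) // 2001:db8::2: global scope and not deprecated.
}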
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index fb7ac409e..fc56a6d79 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -21,13 +21,13 @@ package stack
import (
"encoding/binary"
- "sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -547,6 +547,49 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
+// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address
+// and returns the network protocol number to be used to communicate with the
+// specified address. It returns an error if the passed address is incompatible
+// with the receiver.
+func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.NetProto
+ switch len(addr.Addr) {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ if header.IsV4MappedAddress(addr.Addr) {
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == header.IPv4Any {
+ addr.Addr = ""
+ }
+ }
+ }
+
+ switch len(e.ID.LocalAddress) {
+ case header.IPv4AddressSize:
+ if len(addr.Addr) == header.IPv6AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+ case header.IPv6AddressSize:
+ if len(addr.Addr) == header.IPv4AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ switch {
+ case netProto == e.NetProto:
+ case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ if v6only {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ }
+ default:
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return addr, netProto, nil
+}
+
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
// marker interface.
func (*TransportEndpointInfo) IsEndpointInfo() {}
@@ -707,7 +750,9 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables the packet forwarding between NICs.
//
// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up.
+// cleaned up and, if IPv6 is enabled, NDP Router Solicitations will be stopped.
+// When forwarding becomes disabled and IPv6 is enabled, NDP Router
+// Solicitations will be started.
func (s *Stack) SetForwarding(enable bool) {
// TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.Lock()
@@ -729,6 +774,10 @@ func (s *Stack) SetForwarding(enable bool) {
for _, nic := range s.nics {
nic.becomeIPv6Router()
}
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
@@ -796,6 +845,9 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
}
+// NICContext is an opaque pointer used to store client-supplied NIC metadata.
+type NICContext interface{}
+
// NICOptions specifies the configuration of a NIC as it is being created.
// The zero value creates an enabled, unnamed NIC.
type NICOptions struct {
@@ -805,6 +857,12 @@ type NICOptions struct {
// Disabled specifies whether to avoid calling Attach on the passed
// LinkEndpoint.
Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
}
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
@@ -819,7 +877,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, opts.Name, ep)
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
if !opts.Disabled {
@@ -860,7 +918,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
return false
}
-// NICSubnets returns a map of NICIDs to their associated subnets.
+// NICAddressRanges returns a map of NICIDs to their associated subnets.
func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -886,6 +944,18 @@ type NICInfo struct {
MTU uint32
Stats NICStats
+
+ // Context is user-supplied data optionally set via CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
}
// NICInfo returns a map of NICIDs to their associated information.
@@ -908,6 +978,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
+ Context: nic.context,
}
}
return nics
@@ -1041,9 +1112,9 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
if len(localAddr) == 0 {
- return nic.primaryEndpoint(netProto)
+ return nic.primaryEndpoint(netProto, remoteAddr)
}
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
@@ -1059,7 +1130,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
@@ -1069,7 +1140,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
continue
}
if nic, ok := s.nics[route.NIC]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
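
The V4-mapped handling in AddrNetProto above is easiest to see with concrete bytes. The standalone sketch below mirrors just the unwrapping step using plain byte strings; the size constants are local stand-ins for the header package's values:

package main

import "fmt"

const (
	ipv4AddressSize = 4
	ipv6AddressSize = 16
)

// isV4Mapped reports whether a 16-byte address is of the form
// ::ffff:a.b.c.d, i.e. an IPv4 address embedded in IPv6.
func isV4Mapped(addr string) bool {
	if len(addr) != ipv6AddressSize {
		return false
	}
	for i := 0; i < 10; i++ {
		if addr[i] != 0 {
			return false
		}
	}
	return addr[10] == 0xff && addr[11] == 0xff
}

func main() {
	// ::ffff:192.0.2.1 as raw bytes.
	mapped := "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xc0\x00\x02\x01"

	if isV4Mapped(mapped) {
		// AddrNetProto keeps only the trailing 4 bytes and switches the
		// network protocol to IPv4 for such addresses.
		v4 := mapped[ipv6AddressSize-ipv4AddressSize:]
		fmt.Printf("%d.%d.%d.%d\n", v4[0], v4[1], v4[2], v4[3]) // 192.0.2.1
	}
}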
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9ac50bb23..4b3d18f1b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
@@ -2001,6 +2002,46 @@ func TestNICAutoGenAddr(t *testing.T) {
}
}
+// TestNICContextPreservation tests that you can read out via stack.NICInfo the
+// Context data you pass via stack.NICOptions.Context in stack.CreateNICWithOptions.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+ t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
+ }
+}
+
// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local
// addresses with opaque interface identifiers. Link Local addresses should
// always be generated with opaque IIDs if configured to use them, even if the
@@ -2371,3 +2412,154 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
}
}
}
+
+func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
+ const (
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
+ )
+
+ // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
+ tests := []struct {
+ name string
+ nicAddrs []tcpip.Address
+ connectAddr tcpip.Address
+ expectedLocalAddr tcpip.Address
+ }{
+ // Test Rule 1 of RFC 6724 section 5.
+ {
+ name: "Same Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 2 of RFC 6724 section 5.
+ {
+ name: "Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test returning the endpoint that is closest to the front when
+ // candidate addresses are "equal" from the perspective of RFC 6724
+ // section 5.
+ {
+ name: "Unique Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Link Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local for Unique Local",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+
+ for _, a := range test.nicAddrs {
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
+ t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ }
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ })
+ }
+}
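
addrForNewConnectionTo is a test helper defined elsewhere in this file and is not part of the hunk above. A sketch of what such a helper plausibly does (create a UDP endpoint, connect it so source address selection runs, then read back the chosen local address) is shown below; it assumes the test file's existing imports plus pkg/waiter, and the in-tree helper may differ:

// Hypothetical helper; the real addrForNewConnectionTo lives outside this diff.
func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
	t.Helper()

	var wq waiter.Queue
	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
	if err != nil {
		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ipv6.ProtocolNumber, err)
	}
	defer ep.Close()

	// Connecting forces source address selection for the remote address.
	if err := ep.Connect(addr); err != nil {
		t.Fatalf("ep.Connect(%+v): %s", addr, err)
	}
	got, err := ep.GetLocalAddress()
	if err != nil {
		t.Fatalf("ep.GetLocalAddress(): %s", err)
	}
	return got.Addr
}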
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 67c21be42..d686e6eb8 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -18,8 +18,8 @@ import (
"fmt"
"math/rand"
"sort"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
return
}
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt)
+ transEP := selectEndpoint(id, mpep, epsByNic.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNic.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
}
@@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
epsByNic.mu.Lock()
defer epsByNic.mu.Unlock()
@@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort
}
// This is a new binding.
- multiPortEp := &multiPortEndpoint{}
+ multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto}
multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
multiPortEp.reuse = reusePort
epsByNic.endpoints[bindToDevice] = multiPortEp
@@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T
// newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable.
- protocol map[protocolIDs]*transportEndpoints
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol, if implemented by a transport protocol, causes the
+// dispatcher to deliver packets to the protocol's QueuePacket method instead
+// of calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
- d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
+ }
}
}
@@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
//
// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
// reuse indicates if more than one endpoint is allowed.
@@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
ep.mu.RLock()
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
for i, endpoint := range ep.endpointsArr {
// HandlePacket takes ownership of pkt, so each endpoint needs
// its own copy except for the final one.
if i == len(ep.endpointsArr)-1 {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ break
+ }
endpoint.HandlePacket(r, id, pkt)
break
}
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ continue
+ }
endpoint.HandlePacket(r, id, pkt.Clone())
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
@@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
if epsByNic, ok := eps.endpoints[id]; ok {
// There was already a binding.
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// This is a new binding.
@@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
}
eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
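
Only transport protocols that also satisfy the (unexported) queuedTransportProtocol interface get this treatment; in this change the TCP package appears to take advantage of it via the new dispatcher.go added to the tcp target below. A hypothetical implementation from another package, assuming the stack and tcpip packages are imported, would look roughly like this; the type and its queue are invented for illustration:

// exampleProtocol is a hypothetical transport protocol; the queue and its
// consumer are invented for illustration only.
type exampleProtocol struct {
	packets chan queuedPacket
}

type queuedPacket struct {
	route *stack.Route
	ep    stack.TransportEndpoint
	id    stack.TransportEndpointID
	pkt   tcpip.PacketBuffer
}

// QueuePacket satisfies queuedTransportProtocol: rather than the demuxer
// calling ep.HandlePacket inline, the packet is handed to the protocol, which
// can process it later on its own goroutine.
func (p *exampleProtocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
	p.packets <- queuedPacket{route: r, ep: ep, id: id, pkt: pkt}
}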
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index df5ced887..5e9237de9 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -41,7 +41,7 @@ const (
type testContext struct {
t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
@@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) {
}
}
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and a NIC for each ID in linkEpIDs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicID := tcpip.NICID(i + 1)
- opts := stack.NICOptions{Name: linkEpName}
- if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
return &testContext{
t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
@@ -122,7 +119,7 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
func TestDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
- bindToDevice string
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
@@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) {
endpoints []endpointSockopts
// wantedDistribution is the wanted ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantedDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- {1, ""},
- {1, ""},
- {1, ""},
- {1, ""},
- {1, ""},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": {0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- {0, "dev0"},
- {0, "dev1"},
- {0, "dev2"},
+ {0, 1},
+ {0, 2},
+ {0, 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": {1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": {0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": {0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- {1, "dev0"},
- {1, "dev0"},
- {1, "dev1"},
- {1, "dev1"},
- {1, "dev1"},
- {1, ""},
+ {1, 1},
+ {1, 1},
+ {1, 2},
+ {1, 2},
+ {1, 2},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": {0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": {0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
t.Run(test.name, func(t *testing.T) {
for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
+ t.Run(string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
for d := range test.wantedDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
defer c.cleanup()
c.createV6Endpoint(false)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 1eca76c30..b7813cbc0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -35,10 +35,10 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
@@ -322,7 +322,7 @@ type ControlMessages struct {
HasTOS bool
// TOS is the IPv4 type of service of the associated packet.
- TOS int8
+ TOS uint8
// HasTClass indicates whether Tclass is valid/set.
HasTClass bool
@@ -500,9 +500,13 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTOSOption is used by {G,S}etSockOptBool to specify if the TOS
+ // ancillary message is passed with incoming packets.
+ ReceiveTOSOption SockOptBool = iota
+
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
- V6OnlyOption SockOptBool = iota
+ V6OnlyOption
)
// SockOptInt represents socket options which values have the int type.
@@ -552,7 +556,7 @@ type ReusePortOption int
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
+type BindToDeviceOption NICID
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
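
The new ReceiveTOSOption and the retyped BindToDeviceOption are used roughly as follows. This is a sketch against an already-created endpoint ep; it assumes bool options go through {G,S}etSockOptBool, as the V6OnlyOption comment describes, and that SetSockOpt still accepts the option value directly:

// Request the TOS ancillary message for incoming packets.
if err := ep.SetSockOptBool(tcpip.ReceiveTOSOption, true); err != nil {
	// Handle the *tcpip.Error.
}

// Bind the socket to NIC 1; the option now carries a NICID rather than a
// device name.
if err := ep.SetSockOpt(tcpip.BindToDeviceOption(1)); err != nil {
	// Handle the *tcpip.Error.
}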
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
index 1f735d735..2d20f7ef3 100644
--- a/pkg/tcpip/timer_test.go
+++ b/pkg/tcpip/timer_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package timer_test
+package tcpip_test
import (
"sync"
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index d8c5b5058..3aa23d529 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -28,6 +28,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index c7ce74cdd..42afb3f5b 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,8 +15,7 @@
package icmp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -289,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
toCopy := *to
to = &toCopy
- netProto, err := e.checkV4Mapped(to, true)
+ netProto, err := e.checkV4Mapped(to)
if err != nil {
return 0, nil, err
}
@@ -476,18 +475,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
})
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- return 0, tcpip.ErrNoRoute
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -519,7 +512,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -632,7 +625,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 44b58ff6b..4858d150c 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -28,6 +28,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 07ffa8aba..fc5bc69fa 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,8 +25,7 @@
package packet
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 00991ac8e..2f2131ff7 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -29,6 +29,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 85f7eb76b..ee9c4c58b 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,8 +26,7 @@
package raw
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 3b353d56c..0e3ab05ad 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -16,6 +16,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tcp_endpoint_list",
+ out = "tcp_endpoint_list.go",
+ package = "tcp",
+ prefix = "endpoint",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*endpoint",
+ "Linker": "*endpoint",
+ },
+)
+
go_library(
name = "tcp",
srcs = [
@@ -23,6 +35,7 @@ go_library(
"connect.go",
"cubic.go",
"cubic_state.go",
+ "dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
@@ -38,6 +51,7 @@ go_library(
"segment_state.go",
"snd.go",
"snd_state.go",
+ "tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
],
@@ -45,9 +59,9 @@ go_library(
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
- "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 5422ae80c..1a2e3efa9 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -19,11 +19,11 @@ import (
"encoding/binary"
"hash"
"io"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
if l.listenEP != nil {
l.listenEP.mu.Lock()
- if l.listenEP.state != StateListen {
+ if l.listenEP.EndpointState() != StateListen {
l.listenEP.mu.Unlock()
return nil, tcpip.ErrConnectionAborted
}
@@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() {
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.pendingAccepted.Add(1)
defer e.pendingAccepted.Done()
acceptedChan := e.acceptedChan
e.mu.Unlock()
+
if state == StateListen {
acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
@@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// We do not use transitionToStateEstablishedLocked here as there is
// no handshake state available when doing a SYN cookie based accept.
n.stack.Stats().TCP.CurrentEstablished.Increment()
- n.state = StateEstablished
n.isConnectNotified = true
+ n.setEndpointState(StateEstablished)
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
- e.state = StateClose
+ e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index cdd69f360..a2f384384 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -16,11 +16,11 @@ package tcp
import (
"encoding/binary"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.mss = opts.MSS
h.sndWndScale = opts.WS
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
}
@@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// SYN-RCVD state.
h.state = handshakeSynRcvd
h.ep.mu.Lock()
- h.ep.state = StateSynRecv
ttl := h.ep.ttl
+ h.ep.setEndpointState(StateSynRecv)
h.ep.mu.Unlock()
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
// permits SACK. This is not explicitly defined in the RFC but
@@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
WS: h.rcvWndScale,
TS: h.ep.sendTSOk,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
@@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error {
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
@@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
if e.sackPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
@@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error {
}
func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
// Drain the send queue.
e.handleWrite()
@@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
- if e.state == StateEstablished || e.state == StateCloseWait {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- }
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
@@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
}
// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
+// exit.
func (e *endpoint) completeWorkerLocked() {
+ // The worker is terminating (either because the endpoint moved to the
+ // CLOSED or ERROR state); ensure we release all registrations and port
+ // reservations even if the socket itself is not yet closed by the
+ // application.
e.workerRunning = false
if e.workerCleanup {
e.cleanupLocked()
@@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
e.rcvAutoParams.prevCopied = int(h.rcvWnd)
e.rcvListMu.Unlock()
}
- h.ep.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
+ e.setEndpointState(StateEstablished)
}
// transitionToStateCloseLocked ensures that the endpoint is
@@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
// delivered to this endpoint from the demuxer when the endpoint
// is transitioned to StateClose.
func (e *endpoint) transitionToStateCloseLocked() {
- if e.state == StateClose {
+ if e.EndpointState() == StateClose {
return
}
+ // Mark the endpoint as fully closed for reads/writes.
e.cleanupLocked()
- e.state = StateClose
+ e.setEndpointState(StateClose)
e.stack.Stats().TCP.EstablishedClosed.Increment()
}
@@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
s.decRef()
return
}
- ep.(*endpoint).enqueueSegment(s)
+ if ep.(*endpoint).enqueueSegment(s) {
+ ep.(*endpoint).newSegmentWaker.Assert()
+ }
}
func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
@@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// except SYN-SENT, all reset (RST) segments are
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
- s.decRef()
e.mu.Lock()
- switch e.state {
+ switch e.EndpointState() {
// In case of a RST in CLOSE-WAIT linux moves
// the socket to closed state with an error set
// to indicate EPIPE.
@@ -981,103 +984,53 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
e.mu.Unlock()
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
return false, tcpip.ErrConnectionReset
}
}
return true, nil
}
-// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState() == StateClose || e.EndpointState() == StateError {
+ return nil
+ }
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
break
}
- // Invoke the tcp probe if installed.
- if e.probe != nil {
- e.probe(e.completeState())
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
}
-
- if s.flagIsSet(header.TCPFlagRst) {
- if ok, err := e.handleReset(s); !ok {
- return err
- }
- } else if s.flagIsSet(header.TCPFlagSyn) {
- // See: https://tools.ietf.org/html/rfc5961#section-4.1
- // 1) If the SYN bit is set, irrespective of the sequence number, TCP
- // MUST send an ACK (also referred to as challenge ACK) to the remote
- // peer:
- //
- // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
- //
- // After sending the acknowledgment, TCP MUST drop the unacceptable
- // segment and stop processing further.
- //
- // By sending an ACK, the remote peer is challenged to confirm the loss
- // of the previous connection and the request to start a new connection.
- // A legitimate peer, after restart, would not have a TCB in the
- // synchronized state. Thus, when the ACK arrives, the peer should send
- // a RST segment back with the sequence number derived from the ACK
- // field that caused the RST.
-
- // This RST will confirm that the remote peer has indeed closed the
- // previous connection. Upon receipt of a valid RST, the local TCP
- // endpoint MUST terminate its connection. The local TCP endpoint
- // should then rely on SYN retransmission from the remote end to
- // re-establish the connection.
-
- e.snd.sendAck()
- } else if s.flagIsSet(header.TCPFlagAck) {
- // Patch the window size in the segment according to the
- // send window scale.
- s.window <<= e.snd.sndWndScale
-
- // RFC 793, page 41 states that "once in the ESTABLISHED
- // state all segments must carry current acknowledgment
- // information."
- drop, err := e.rcv.handleRcvdSegment(s)
- if err != nil {
- s.decRef()
- return err
- }
- if drop {
- s.decRef()
- continue
- }
-
- // Now check if the received segment has caused us to transition
- // to a CLOSED state, if yes then terminate processing and do
- // not invoke the sender.
- e.mu.RLock()
- state := e.state
- e.mu.RUnlock()
- if state == StateClose {
- // When we get into StateClose while processing from the queue,
- // return immediately and let the protocolMainloop handle it.
- //
- // We can reach StateClose only while processing a previous segment
- // or a notification from the protocolMainLoop (caller goroutine).
- // This means that with this return, the segment dequeue below can
- // never occur on a closed endpoint.
- s.decRef()
- return nil
- }
- e.snd.handleRcvdSegment(s)
+ if !cont {
+ s.decRef()
+ return nil
}
- s.decRef()
}
- // If the queue is not empty, make sure we'll wake up in the next
- // iteration.
- if checkRequeue && !e.segmentQueue.empty() {
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
@@ -1086,11 +1039,88 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- e.resetKeepaliveTimer(true)
+ e.resetKeepaliveTimer(true /* receivedData */)
return nil
}
+// handleSegment handles a given segment and notifies the worker goroutine
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
@@ -1160,7 +1190,7 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -1182,6 +1212,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.mu.Unlock()
+ e.workMu.Unlock()
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -1193,7 +1224,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
e.mu.Lock()
- h.ep.state = StateSynSent
+ h.ep.setEndpointState(StateSynSent)
e.mu.Unlock()
if err := h.execute(); err != nil {
@@ -1202,12 +1233,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.mu.Lock()
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
// Lock released below.
epilogue()
-
return err
}
}
@@ -1215,7 +1245,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- // Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
drained := e.drainDone != nil
e.mu.Unlock()
@@ -1224,8 +1253,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
<-e.undrain
}
- e.waiterQueue.Notify(waiter.EventOut)
-
// Set up the functions that will be called when the main protocol loop
// wakes up.
funcs := []struct {
@@ -1241,17 +1268,14 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
f: e.handleClose,
},
{
- w: &e.newSegmentWaker,
- f: e.handleSegments,
- },
- {
w: &closeWaker,
f: func() *tcpip.Error {
// This means the socket is being closed due
- // to the TCP_FIN_WAIT2 timeout was hit. Just
+ // to the TCP-FIN-WAIT2 timeout was hit. Just
// mark the socket as closed.
e.mu.Lock()
e.transitionToStateCloseLocked()
+ e.workerCleanup = true
e.mu.Unlock()
return nil
},
@@ -1267,6 +1291,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
},
},
{
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
@@ -1293,14 +1323,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
if n&notifyReset != 0 {
- e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.mu.Unlock()
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
}
if n&notifyClose != 0 && closeTimer == nil {
e.mu.Lock()
- if e.state == StateFinWait2 && e.closed {
+ if e.EndpointState() == StateFinWait2 && e.closed {
// The socket has been closed and we are in FIN_WAIT2
// so start the FIN_WAIT2 timer.
closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() {
@@ -1320,11 +1352,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(); err != nil {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
return err
}
}
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
// Only block the worker if the endpoint
// is not in closed state or error state.
close(e.drainDone)
@@ -1349,14 +1381,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
// The following assertions and notifications are needed for restored
// endpoints. Fresh newly created endpoints have empty states and should
// not invoke any.
- e.segmentQueue.mu.Lock()
- if !e.segmentQueue.list.Empty() {
+ if !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
- e.segmentQueue.mu.Unlock()
e.rcvListMu.Lock()
if !e.rcvList.Empty() {
@@ -1371,28 +1410,53 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
+ cleanupOnError := func(err *tcpip.Error) {
+ e.mu.Lock()
+ e.workerCleanup = true
+ if err != nil {
+ e.resetConnectionLocked(err)
+ }
+ // Lock released below.
+ epilogue()
+ }
- for e.state != StateTimeWait && e.state != StateClose && e.state != StateError {
+loop:
+ for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError {
e.mu.Unlock()
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if err := funcs[v].f(); err != nil {
- e.mu.Lock()
- // Ensure we release all endpoint registration and route
- // references as the connection is now in an error
- // state.
- e.workerCleanup = true
- e.resetConnectionLocked(err)
- // Lock released below.
- epilogue()
+ // We need to double check here because the notification may be
+ // stale by the time we get around to processing it.
+ //
+ // NOTE: since we now hold the workMu the processors cannot
+ // change the state of the endpoint so it's safe to proceed
+ // after this check.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ e.mu.Lock()
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
+ e.mu.Lock()
}
- e.mu.Lock()
}
- state := e.state
+ state := e.EndpointState()
e.mu.Unlock()
var reuseTW func()
if state == StateTimeWait {
@@ -1405,13 +1469,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.Done()
// Wake up any waiters before we enter TIME_WAIT.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.mu.Lock()
+ e.workerCleanup = true
+ e.mu.Unlock()
reuseTW = e.doTimeWait()
}
// Mark endpoint as closed.
e.mu.Lock()
- if e.state != StateError {
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ if e.EndpointState() != StateError {
e.transitionToStateCloseLocked()
}
@@ -1468,7 +1534,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()
tcpEP := listenEP.(*endpoint)
if EndpointState(tcpEP.State()) == StateListen {
reuseTW = func() {
- tcpEP.enqueueSegment(s)
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
}
// We explicitly do not decRef
// the segment as it's still
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..e18012ac0
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,224 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available;
+// it returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ newEndpointWaker sleep.Waker
+ id int
+}
+
+func newProcessor(id int) *processor {
+ p := &processor{
+ id: id,
+ }
+ go p.handleSegments()
+ return p
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+func (p *processor) handleSegments() {
+ const newEndpointWaker = 1
+ s := sleep.Sleeper{}
+ s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ defer s.Done()
+ for {
+ s.Fetch(true)
+ for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If the socket has transitioned out of the connected state
+ // then just let the worker handle the packet.
+ //
+ // NOTE: We read this outside of e.mu lock which means
+ // that by the time we get to handleSegments the
+ // endpoint may not be in ESTABLISHED. But this should
+ // be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a
+ // CLOSED/ERROR state then handleSegments is a noop.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+
+ if !ep.workMu.TryLock() {
+ ep.newSegmentWaker.Assert()
+ continue
+ }
+ // If the endpoint is in a connected state then we do
+ // direct delivery to ensure low latency and avoid
+ // scheduler interactions.
+ if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose {
+ // Send any active resets if required.
+ if err != nil {
+ ep.mu.Lock()
+ ep.resetConnectionLocked(err)
+ ep.mu.Unlock()
+ }
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ ep.workMu.Unlock()
+ continue
+ }
+
+ if !ep.segmentQueue.empty() {
+ p.epQ.enqueue(ep)
+ }
+
+ ep.workMu.Unlock()
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines does full TCP processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []*processor
+ seed uint32
+}
+
+func newDispatcher(nProcessors int) *dispatcher {
+ processors := []*processor{}
+ for i := 0; i < nProcessors; i++ {
+ processors = append(processors, newProcessor(i))
+ }
+ return &dispatcher{
+ processors: processors,
+ seed: generateRandUint32(),
+ }
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in the established state, let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ payload := []byte{
+ byte(id.LocalPort),
+ byte(id.LocalPort >> 8),
+ byte(id.RemotePort),
+ byte(id.RemotePort >> 8)}
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload)
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return d.processors[h.Sum32()%uint32(len(d.processors))]
+}
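For readers skimming the new dispatcher, a minimal standalone sketch of the processor-selection idea follows (not part of the patch): the connection 4-tuple is hashed together with a per-dispatcher seed so every segment of a given connection maps to the same processor index, which is what preserves per-connection ordering. hash/fnv stands in for the jenkins hash used above; the names and values are illustrative only.

    package main

    import (
    	"fmt"
    	"hash/fnv"
    )

    // selectWorker mirrors the shape of selectProcessor above: seed plus
    // 4-tuple in, stable worker index out.
    func selectWorker(seed uint32, localPort, remotePort uint16, localAddr, remoteAddr string, nWorkers int) int {
    	h := fnv.New32a()
    	h.Write([]byte{byte(seed), byte(seed >> 8), byte(seed >> 16), byte(seed >> 24)})
    	h.Write([]byte{byte(localPort), byte(localPort >> 8), byte(remotePort), byte(remotePort >> 8)})
    	h.Write([]byte(localAddr))
    	h.Write([]byte(remoteAddr))
    	return int(h.Sum32() % uint32(nWorkers))
    }

    func main() {
    	// The same 4-tuple always selects the same worker, so its segments stay
    	// in order; different connections spread across the pool of 8 workers.
    	fmt.Println(selectWorker(42, 8080, 51000, "10.0.0.1", "10.0.0.2", 8))
    	fmt.Println(selectWorker(42, 8080, 51000, "10.0.0.1", "10.0.0.2", 8)) // same index as above
    	fmt.Println(selectWorker(42, 8080, 51001, "10.0.0.1", "10.0.0.2", 8)) // likely a different index
    }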
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 2ac1b6877..4797f11d1 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -19,12 +19,12 @@ import (
"fmt"
"math"
"strings"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
@@ -120,6 +120,7 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyResetByPeer
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -127,6 +128,7 @@ const (
// ensures the loop terminates if the final state of the endpoint is
// say TIME_WAIT.
notifyTickleWorker
+ notifyError
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {}
type endpoint struct {
EndpointInfo
+ // endpointEntry is used to queue endpoints for processing to a
+ // given TCP processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ pendingProcessing bool `state:"nosave"`
+
// workMu is used to arbitrate which goroutine may perform protocol
// work. Only the main protocol goroutine is expected to call Lock() on
// it, but other goroutines (e.g., send) may call TryLock() to eagerly
@@ -324,6 +338,7 @@ type endpoint struct {
// The following fields are protected by the mutex.
mu sync.RWMutex `state:"nosave"`
+ // state must be read/set using the EndpointState()/setEndpointState() methods.
state EndpointState `state:".(EndpointState)"`
// origEndpointState is only used during a restore phase to save the
@@ -359,7 +374,7 @@ type endpoint struct {
workerRunning bool
// workerCleanup specifies if the worker goroutine must perform cleanup
- // before exitting. This can only be set to true when workerRunning is
+ // before exiting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
workerCleanup bool
@@ -371,6 +386,8 @@ type endpoint struct {
// recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
+ //
+ // recentTS must be read/written atomically.
recentTS uint32
// tsOffset is a randomized offset added to the value of the
@@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() {
e.workMu.Unlock()
}
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp atomically sets the recentTS field to the
+// provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ atomic.StoreUint32(&e.recentTS, recentTS)
+}
+
+// recentTimestamp atomically reads and returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return atomic.LoadUint32(&e.recentTS)
+}
+
// keepalive is a synchronization wrapper used to appease stateify. See the
// comment in endpoint, where it is used.
//
@@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
e.mu.RLock()
defer e.mu.RUnlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
@@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
}
}
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -733,14 +791,20 @@ func (e *endpoint) Close() {
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdown()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4-way TCP shutdown.
+func (e *endpoint) closeNoShutdown() {
e.mu.Lock()
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == StateListen && e.isPortReserved {
+ if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice)
e.isRegistered = false
@@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
defer close(done)
for n := range e.acceptedChan {
n.notifyProtocolGoroutine(notifyReset)
+ // Close all connections that have completed but have
+ // not been accepted by the application.
n.Close()
}
}()
@@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
// after Close() is called and the worker goroutine (if any) is done with its
// work.
func (e *endpoint) cleanupLocked() {
+
// Close all endpoints that might have been accepted by TCP but not by
// the client.
if e.acceptedChan != nil {
e.closePendingAcceptableConnectionsLocked()
}
+
e.workerCleanup = false
if e.isRegistered {
@@ -885,8 +953,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = rcvWnd
- e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -914,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
e.mu.RUnlock()
@@ -938,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -955,11 +1029,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
}
e.rcvBufUsed -= len(v)
- // If the window was zero before this read and if the read freed up
- // enough buffer space for the scaled window to be non-zero then notify
- // the protocol goroutine to send a window update.
- if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
- e.zeroWindow = false
+
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -973,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
- if !e.state.connected() {
- switch e.state {
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
case StateError:
return 0, e.HardError
default:
@@ -1032,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, perr
}
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if opts.Atomic {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+ }
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
+ if e.workMu.TryLock() {
+ // Since we released locks in between, it's possible that the
+ // endpoint transitioned to the CLOSED/ERROR state, so make
+ // sure the endpoint is still writable before trying to write.
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
}
- }
-
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
-
- if e.workMu.TryLock() {
// Do the work inline.
e.handleWrite()
e.workMu.Unlock()
} else {
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
+
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+ // Release the endpoint lock to prevent deadlocks due to lock
+ // order inversion when acquiring workMu.
+ e.mu.RUnlock()
+
+ }
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
@@ -1084,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; !s.connected() && s != StateClose {
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
}
@@ -1096,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
@@ -1133,16 +1252,38 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// zeroReceiveWindow checks if the receive window to be announced now would be
-// zero, based on the amount of available buffer and the receive window scaling.
+// windowCrossedACKThreshold checks if the receive window to be announced now
+// would be under aMSS or under half the receive buffer, whichever is smaller.
+// This is useful as a receive-side silly window syndrome prevention mechanism.
+// If the window grows to a reasonable value, we should send an ACK to the
+// sender to inform it that the rx space is now large. We also want to ensure
+// that a series of small reads won't trigger a flood of spurious tiny ACKs.
//
-// It must be called with rcvListMu held.
-func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
- if e.rcvBufUsed >= e.rcvBufSize {
- return true
+// For large receive buffers, the threshold is aMSS - once the reader reads
+// more than aMSS we'll send an ACK. For tiny receive buffers, the threshold
+// is half of the receive buffer size. This is chosen arbitrarily.
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) {
+ newAvail := e.receiveBufferAvailableLocked()
+ oldAvail := newAvail - deltaBefore
+ if oldAvail < 0 {
+ oldAvail = 0
}
- return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+ threshold := int(e.amss)
+ if threshold > e.rcvBufSize/2 {
+ threshold = e.rcvBufSize / 2
+ }
+
+ switch {
+ case oldAvail < threshold && newAvail >= threshold:
+ return true, true
+ case oldAvail >= threshold && newAvail < threshold:
+ return true, false
+ }
+ return false, false
}
// SetSockOptBool sets a socket option.
@@ -1158,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -1204,10 +1345,16 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
size = math.MaxInt32 / 2
}
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = size
+ availAfter := e.receiveBufferAvailableLocked()
+
e.rcvAutoParams.disabled = true
- if e.zeroWindow && !e.zeroReceiveWindow(scale) {
- e.zeroWindow = false
+
+ // Immediately send an ACK to uncork the sender once our available
+ // space grows above aMSS or half the receive buffer, whichever is
+ // smaller (silly window syndrome prevention).
+ if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
e.rcvListMu.Unlock()
@@ -1279,19 +1426,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.QuickAckOption:
if v == 0 {
@@ -1372,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// Acquire the work mutex as we may need to
// reinitialize the congestion control state.
e.mu.Lock()
- state := e.state
+ state := e.EndpointState()
e.cc = v
e.mu.Unlock()
switch state {
case StateEstablished:
e.workMu.Lock()
e.mu.Lock()
- if e.state == state {
+ if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
e.mu.Unlock()
@@ -1442,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
defer e.mu.RUnlock()
// The endpoint cannot be in listen state.
- if e.state == StateListen {
+ if e.EndpointState() == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1550,12 +1692,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.QuickAckOption:
@@ -1665,26 +1803,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -1720,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return err
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// The endpoint is already connected. If caller hasn't been
// notified yet, return success.
if !e.isConnectNotified {
@@ -1732,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
nicID := addr.NIC
- switch e.state {
+ switch e.EndpointState() {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
@@ -1839,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.isRegistered = true
- e.state = StateConnecting
+ e.setEndpointState(StateConnecting)
e.route = r.Clone()
e.boundNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1860,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = StateEstablished
- e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.setEndpointState(StateEstablished)
}
if run {
e.workerRunning = true
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
return tcpip.ErrConnectStarted
@@ -1885,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.shutdownFlags |= flags
finQueued := false
switch {
- case e.state.connected():
+ case e.EndpointState().connected():
// Close for read.
if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
// Mark read side as closed.
@@ -1897,8 +2019,18 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
e.mu.Unlock()
+ // Try to send an active reset immediately if the
+ // work mutex is available.
+ if e.workMu.TryLock() {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ e.mu.Unlock()
+ e.workMu.Unlock()
+ } else {
+ e.notifyProtocolGoroutine(notifyReset)
+ }
return nil
}
}
@@ -1920,11 +2052,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
finQueued = true
// Mark endpoint as closed.
e.sndClosed = true
-
e.sndBufMu.Unlock()
}
- case e.state == StateListen:
+ case e.EndpointState() == StateListen:
// Tell protocolListenLoop to stop.
if flags&tcpip.ShutdownRead != 0 {
e.notifyProtocolGoroutine(notifyClose)
@@ -1965,7 +2096,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// When the endpoint shuts down, it sets workerCleanup to true, and from
// that point onward, acceptedChan is the responsibility of the cleanup()
// method (and should not be touched anywhere else, including here).
- if e.state == StateListen && !e.workerCleanup {
+ if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
if len(e.acceptedChan) > backlog {
@@ -1983,7 +2114,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
return nil
}
- if e.state == StateInitial {
+ if e.EndpointState() == StateInitial {
// The listen is called on an unbound socket, the socket is
// automatically bound to a random free port with the local
// address set to INADDR_ANY.
@@ -1993,7 +2124,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
// Endpoint must be bound before it can transition to listen mode.
- if e.state != StateBound {
+ if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
@@ -2004,24 +2135,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
}
e.isRegistered = true
- e.state = StateListen
+ e.setEndpointState(StateListen)
+
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
e.workerRunning = true
-
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
-
return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.mu.Lock()
e.waiterQueue = waiterQueue
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
}
// Accept returns a new endpoint if a peer has established a connection
@@ -2031,7 +2165,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != StateListen {
+ if e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
@@ -2058,7 +2192,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrAlreadyBound
}
@@ -2120,7 +2254,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
}
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
return nil
}
@@ -2142,7 +2276,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if !e.state.connected() {
+ if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -2153,45 +2287,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
- s := newSegment(r, id, pkt)
- if !s.parse() {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- s.decRef()
- return
- }
-
- if !s.csumValid {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- s.decRef()
- return
- }
-
- e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
- e.stats.SegmentsReceived.Increment()
- if (s.flags & header.TCPFlagRst) != 0 {
- e.stack.Stats().TCP.ResetsReceived.Increment()
- }
-
- e.enqueueSegment(s)
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher which then can either deliver them using the
+ // worker goroutine or directly invoke the TCP processing inline
+ // based on the state of the endpoint.
}
-func (e *endpoint) enqueueSegment(s *segment) {
+func (e *endpoint) enqueueSegment(s *segment) bool {
// Send packet to worker goroutine.
- if e.segmentQueue.enqueue(s) {
- e.newSegmentWaker.Assert()
- } else {
+ if !e.segmentQueue.enqueue(s) {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
- s.decRef()
+ return false
}
+ return true
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
@@ -2234,13 +2345,10 @@ func (e *endpoint) readyToRead(s *segment) {
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
- // Check if the receive window is now closed. If so make sure
- // we set the zero window before we deliver the segment to ensure
- // that a subsequent read of the segment will correctly trigger
- // a non-zero notification.
- if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ // Increase the counter if the receive window falls below MSS
+ // or half the receive buffer size, whichever is smaller.
+ if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- e.zeroWindow = true
}
e.rcvList.PushBack(s)
} else {
@@ -2311,8 +2419,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
- e.recentTS = tsVal
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
}
}
@@ -2322,7 +2430,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
- e.recentTS = synOpts.TSVal
+ e.setRecentTimestamp(synOpts.TSVal)
}
}
@@ -2411,7 +2519,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Endpoint TCP Option state.
s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTS
+ s.RecentTS = e.recentTimestamp()
s.TSOffset = e.tsOffset
s.SACKPermitted = e.sackPermitted
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
@@ -2518,9 +2626,7 @@ func (e *endpoint) initGSO() {
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
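To make the silly-window-avoidance arithmetic introduced in endpoint.go above concrete, here is a small standalone sketch (illustrative only, with hypothetical buffer and MSS values): the threshold is the smaller of the advertised MSS and half the receive buffer, and a notification is only worth sending when the available space crosses that threshold.

    package main

    import "fmt"

    // crossedACKThreshold reproduces the check sketched above:
    // threshold = min(aMSS, rcvBufSize/2); report whether availability crossed
    // it and in which direction.
    func crossedACKThreshold(oldAvail, newAvail, amss, rcvBufSize int) (crossed, above bool) {
    	threshold := amss
    	if threshold > rcvBufSize/2 {
    		threshold = rcvBufSize / 2
    	}
    	switch {
    	case oldAvail < threshold && newAvail >= threshold:
    		return true, true // window grew past the threshold: announce it with an ACK
    	case oldAvail >= threshold && newAvail < threshold:
    		return true, false // window shrank below the threshold
    	}
    	return false, false
    }

    func main() {
    	// A read frees 2000 bytes: available space goes from 500 to 2500 and
    	// crosses a 1460-byte threshold, so the protocol goroutine is notified.
    	fmt.Println(crossedACKThreshold(500, 2500, 1460, 65536)) // true true
    	// A tiny 4 KiB buffer with a 9000-byte MSS uses half the buffer (2048)
    	// as the threshold instead.
    	fmt.Println(crossedACKThreshold(2100, 1900, 9000, 4096)) // true false
    }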
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 7aa4c3f0e..4a46f0ec5 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -16,9 +16,10 @@ package tcp
import (
"fmt"
- "sync"
+ "sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound:
// TODO(b/138137272): this enumeration duplicates
// EndpointState.connected. remove it.
@@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() {
fallthrough
case StateListen, StateConnecting:
e.drainSegmentLocked()
- if e.state != StateClose && e.state != StateError {
+ if e.EndpointState() != StateClose && e.EndpointState() != StateError {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
break
}
- fallthrough
case StateError, StateClose:
- for (e.state == StateError || e.state == StateClose) && e.workerRunning {
+ for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
}
if e.workerRunning {
- panic("endpoint still has worker running in closed or error state")
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
}
default:
- panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
}
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
- if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
+ if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) {
panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
}
}
@@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
// saveState is invoked by stateify.
func (e *endpoint) saveState() EndpointState {
- return e.state
+ return e.EndpointState()
}
// Endpoint loading must be done in the following ordering by their state, to
@@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup
func (e *endpoint) loadState(state EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- if state.connected() {
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if state.connected() || state == StateTimeWait {
connectedLoading.Add(1)
}
switch state {
@@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) {
case StateConnecting, StateSynSent, StateSynRecv:
connectingLoading.Add(1)
}
- e.state = state
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference to increment
+ // metrics is not yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
- // Freeze segment queue before registering to prevent any segments
- // from being delivered while it is being restored.
e.origEndpointState = e.state
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
@@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.workMu.Init()
state := e.origEndpointState
-
switch state {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
@@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
listenLoading.Wait()
connectingLoading.Wait()
bind()
- e.state = StateClose
+ e.setEndpointState(StateClose)
tcpip.AsyncLoading.Done()
}()
}
@@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
+
}
// saveLastError is invoked by stateify.
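The load path above stores the state with atomic.StoreUint32 rather than setEndpointState, and the endpoint.go changes read it back with atomic.LoadUint32. A tiny standalone sketch of that lock-free read / guarded write pattern follows (hypothetical type and method names, not the patch's own):

    package main

    import (
    	"fmt"
    	"sync/atomic"
    )

    type State uint32

    const (
    	StateInitial State = iota
    	StateEstablished
    	StateClose
    )

    type conn struct {
    	state State // written via setState, read via loadState
    }

    // setState is the only writer; callers are expected to hold the write lock.
    func (c *conn) setState(s State) {
    	atomic.StoreUint32((*uint32)(&c.state), uint32(s))
    }

    // loadState may be called from any goroutine without holding the lock.
    func (c *conn) loadState() State {
    	return State(atomic.LoadUint32((*uint32)(&c.state)))
    }

    func main() {
    	c := &conn{}
    	c.setState(StateEstablished)
    	fmt.Println(c.loadState() == StateEstablished) // true
    }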
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 4983bca81..7eb613be5 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -15,8 +15,7 @@
package tcp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index bc718064c..958c06fa7 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,10 +21,11 @@
package tcp
import (
+ "runtime"
"strings"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -104,6 +105,7 @@ type protocol struct {
moderateReceiveBuffer bool
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
+ dispatcher *dispatcher
}
// Number returns the tcp protocol number.
@@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
@@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
}
}
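protocol.go now sizes the dispatcher to runtime.GOMAXPROCS(0). A minimal, self-contained sketch of that fixed-pool pattern follows (illustrative only; a channel stands in for the epQueue and waker plumbing used by the real processors).

    package main

    import (
    	"fmt"
    	"runtime"
    	"sync"
    )

    func main() {
    	// One long-lived worker per available CPU, mirroring newDispatcher above.
    	n := runtime.GOMAXPROCS(0)
    	work := make(chan int)
    	var wg sync.WaitGroup
    	for i := 0; i < n; i++ {
    		wg.Add(1)
    		go func(id int) {
    			defer wg.Done()
    			for item := range work {
    				fmt.Printf("processor %d handled item %d\n", id, item)
    			}
    		}(i)
    	}
    	for j := 0; j < 16; j++ {
    		work <- j
    	}
    	close(work)
    	wg.Wait()
    }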
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 0a5534959..958f03ac1 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -98,12 +98,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
- if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
- // We never got around to announcing a zero window size, so we
- // don't need to immediately announce a nonzero one.
- return
- }
-
// Immediately send an ack.
r.ep.snd.sendAck()
}
@@ -175,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateEstablished:
- r.ep.state = StateCloseWait
+ r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
if s.flagIsSet(header.TCPFlagAck) {
// FIN-ACK, transition to TIME-WAIT.
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
} else {
// Simultaneous close, expecting a final ACK.
- r.ep.state = StateClosing
+ r.ep.setEndpointState(StateClosing)
}
case StateFinWait2:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
}
r.ep.mu.Unlock()
@@ -211,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// shutdown states.
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateFinWait1:
- r.ep.state = StateFinWait2
+ r.ep.setEndpointState(StateFinWait2)
// Notify protocol goroutine that we have received an
// ACK to our FIN so that it can start the FIN_WAIT2
// timer to abort connection if the other side does
// not close within 2MSL.
r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
case StateLastAck:
r.ep.transitionToStateCloseLocked()
}
@@ -273,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
switch state {
case StateCloseWait, StateClosing, StateLastAck:
if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
- s.decRef()
// Just drop the segment as we have
// already received a FIN and this
// segment is after the sequence number
@@ -290,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// trigger a RST.
endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
if rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
if state == StateFinWait1 {
@@ -320,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// the last actual data octet in a segment in
// which it occurs.
if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
- s.decRef()
return true, tcpip.ErrConnectionAborted
}
}
@@ -342,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
r.ep.mu.RLock()
- state := r.ep.state
+ state := r.ep.EndpointState()
closed := r.ep.closed
r.ep.mu.RUnlock()
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index e0759225e..bd20a7ee9 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -15,7 +15,7 @@
package tcp
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 8a947dc66..b74b61e7d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -16,11 +16,11 @@ package tcp
import (
"math"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -442,6 +442,13 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
+ // when writeList is empty. Remove this once we have a proper fix for this
+ // issue.
+ if s.writeList.Front() == nil {
+ return true
+ }
+
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
@@ -698,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
- // Transition to FIN-WAIT1 state since we're initiating an active close.
- s.ep.mu.Lock()
- switch s.ep.state {
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
case StateCloseWait:
- // We've already received a FIN and are now sending our own. The
- // sender is now awaiting a final ACK for this FIN.
- s.ep.state = StateLastAck
+ s.ep.setEndpointState(StateLastAck)
default:
- s.ep.state = StateFinWait1
+ s.ep.setEndpointState(StateFinWait1)
}
- s.ep.mu.Unlock()
+
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 15745ebd4..a9dfbe857 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SeqNum(uint32(c.IRS+1)),
checker.AckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
finHeaders := &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
checker.AckNum(0),
checker.TCPFlags(header.TCPFlagRst),
@@ -1083,12 +1085,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence
+ // number of the FIN itself.
checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
@@ -2091,10 +2096,14 @@ func TestZeroScaledWindowReceive(t *testing.T) {
)
}
- // Read some data. An ack should be sent in response to that.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
+ // Read at least 1MSS of data. An ack should be sent in response to that.
+ sz := 0
+ for sz < defaultMTU {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ sz += len(v)
}
checker.IPv4(t, c.GetPacket(),
@@ -2103,7 +2112,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(len(v)>>ws)),
+ checker.Window(uint16(sz>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -3794,47 +3803,41 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
t.Errorf("CreateNIC failed: %v", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -5443,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
packetsSent++
}
+
// Resume the worker so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -5510,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
stk := c.Stack()
// Set lower limits for auto-tuning tests. This is required because the
// test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
+ // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
@@ -5565,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalSent += mss
packetsSent++
}
+
// Resume it so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -6562,3 +6567,140 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
}
}
+
+func TestIncreaseWindowOnReceive(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after recv() when the window grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads together add up to more than one MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // We now have < 1 MSS in the buffer space. Read the data! An
+ // ack should be sent in response to that. The window was not
+ // zero, but it grew to larger than MSS.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestIncreaseWindowOnBufferResize(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after available recv buffer grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads together add up to more than one MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // Increasing the receive buffer should generate an ACK, since the
+ // window grew from a small value to at least one MSS.
+ c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+ // After growing the buffer, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
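Both new tests check the same behaviour from two directions: an ACK carrying a window update should go out once the receivable window grows by at least one MSS, whether that growth comes from the application reading data or from enlarging the receive buffer. A minimal sketch of that decision follows; the names are hypothetical and this is not the netstack implementation.

package main

import "fmt"

// shouldAnnounceWindow reports whether the window we could now advertise has
// grown by at least one MSS over what the peer last saw. Purely illustrative.
func shouldAnnounceWindow(lastAdvertised, nowAvailable, mss int) bool {
	return nowAvailable-lastAdvertised >= mss
}

func main() {
	const mss = 1460
	last := 100 // the small window the peer currently sees
	fmt.Println(shouldAnnounceWindow(last, 500, mss))  // false: growth is still below one MSS
	fmt.Println(shouldAnnounceWindow(last, 4000, mss)) // true: a read (or buffer resize) freed >= 1 MSS
}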
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 97e4d5825..57ff123e3 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -30,6 +30,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1a5ee6317..c9cbed8f4 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,7 @@
package udp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -32,6 +31,7 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -114,6 +114,10 @@ type endpoint struct {
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -244,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ return p.data.ToView(), cm, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -403,7 +418,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrBroadcastDisabled
}
- netProto, err := e.checkV4Mapped(to, false)
+ netProto, err := e.checkV4Mapped(to)
if err != nil {
return 0, nil, err
}
@@ -459,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -502,7 +523,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa, false)
+ netProto, err := e.checkV4Mapped(&fa)
if err != nil {
return err
}
@@ -631,19 +652,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -670,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
switch opt {
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
return false, tcpip.ErrUnknownProtocolOption
}
- e.mu.Lock()
+ e.mu.RLock()
v := e.v6only
- e.mu.Unlock()
+ e.mu.RUnlock()
return v, nil
}
@@ -767,12 +789,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
@@ -849,35 +867,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if len(addr.Addr) == 0 {
- return netProto, nil
- }
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
-
- // Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.ID.LocalAddress) == 16 {
- return 0, tcpip.ErrNetworkUnreachable
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only)
+ if err != nil {
+ return 0, err
}
-
+ *addr = unwrapped
return netProto, nil
}
@@ -926,7 +921,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr, false)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1084,7 +1079,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, true)
+ netProto, err := e.checkV4Mapped(&addr)
if err != nil {
return err
}
@@ -1248,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
e.rcvList.PushBack(packet)
e.rcvBufSize += pkt.Data.Size()
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ }
+
packet.timestamp = e.stack.NowNanoseconds()
e.rcvMu.Unlock()
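With the plumbing above, a user of the UDP endpoint can opt in to per-packet TOS delivery and find it in the ControlMessages returned by Read. The sketch below is illustrative only: it assumes an already created and bound UDP tcpip.Endpoint (construction elided) and uses just the option and fields added in this change.

package example

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/tcpip"
)

// readWithTOS is a hypothetical helper, not part of gVisor. It enables
// ReceiveTOSOption and prints the TOS byte attached to the next datagram.
func readWithTOS(ep tcpip.Endpoint) *tcpip.Error {
	if err := ep.SetSockOptBool(tcpip.ReceiveTOSOption, true); err != nil {
		return err
	}
	var from tcpip.FullAddress
	payload, cm, err := ep.Read(&from)
	if err != nil {
		return err
	}
	if cm.HasTOS {
		fmt.Printf("read %d bytes from %v with TOS %#x\n", len(payload), from.Addr, cm.TOS)
	}
	return nil
}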
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 149fff999..ee9d10555 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -56,6 +56,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -513,42 +515,37 @@ func TestBindToDeviceOption(t *testing.T) {
t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -557,8 +554,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -573,12 +570,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -611,15 +608,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -1287,7 +1290,7 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1322,7 +1325,7 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
+ const tos = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
c.t.Errorf("GetSockopt failed: %s", err)
@@ -1349,6 +1352,47 @@ func TestTOSV6(t *testing.T) {
}
}
+func TestReceiveTOSV4(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, broadcast} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
+ }
+ if got != want {
+ c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
+ }
+
+ // Verify that the correct received TOS is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatal("Bind failed:", err)
+ }
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ })
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 6afdb29b7..07778e4f7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -15,4 +15,5 @@ go_test(
size = "medium",
srcs = ["tmutex_test.go"],
embed = [":tmutex"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
index ce34c7962..05540696a 100644
--- a/pkg/tmutex/tmutex_test.go
+++ b/pkg/tmutex/tmutex_test.go
@@ -17,10 +17,11 @@ package tmutex
import (
"fmt"
"runtime"
- "sync"
"sync/atomic"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestBasicLock(t *testing.T) {
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 8f6f180e5..d1885ae66 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -24,4 +24,5 @@ go_test(
"unet_test.go",
],
embed = [":unet"],
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index a3cc6f5d3..5c4b9e8e9 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -19,10 +19,11 @@ import (
"os"
"path/filepath"
"reflect"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func randomFilename() (string, error) {
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b6bbb0ea2..b8fdc3125 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -11,6 +11,7 @@ go_library(
deps = [
"//pkg/fd",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
],
)
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index df59ffab1..13b2ea314 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -27,10 +27,10 @@ import (
"os"
"reflect"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 0427bc41f..1c6890e52 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -24,6 +24,7 @@ go_library(
],
importpath = "gvisor.dev/gvisor/pkg/waiter",
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 8a65ed164..f708e95fa 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -58,7 +58,7 @@
package waiter
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// EventMask represents io events as used in the poll() syscall.
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 6226b63f8..3e20f8f2f 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/usage",
"//pkg/sentry/usermem",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
@@ -114,6 +115,7 @@ go_test(
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 352e710d2..9c23b9553 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -17,7 +17,6 @@ package boot
import (
"fmt"
"os"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
@@ -27,6 +26,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/strace"
spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
func initCompatLogs(fd int) error {
diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go
index d1c0bb9b5..ce62236e5 100644
--- a/runsc/boot/limits.go
+++ b/runsc/boot/limits.go
@@ -16,12 +16,12 @@ package boot
import (
"fmt"
- "sync"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Mapping from linux resource names to limits.LimitType.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index bc1d0c1bb..fad72f4ab 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "sync"
"sync/atomic"
"syscall"
gtime "time"
@@ -46,6 +45,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 147ff7703..bec0dc292 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -19,7 +19,6 @@ import (
"math/rand"
"os"
"reflect"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/fsgofer"
)
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 250845ad7..b94bc4fa0 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -44,6 +44,7 @@ go_library(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
index a4e3071b3..1815c93b9 100644
--- a/runsc/cmd/create.go
+++ b/runsc/cmd/create.go
@@ -16,6 +16,7 @@ package cmd
import (
"context"
+
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 4831210c0..7df7995f0 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -21,7 +21,6 @@ import (
"os"
"path/filepath"
"strings"
- "sync"
"syscall"
"flag"
@@ -30,6 +29,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/fsgofer"
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index de2115dff..5e9bc53ab 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -16,6 +16,7 @@ package cmd
import (
"context"
+
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 2bd12120d..6dea179e4 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -18,6 +18,7 @@ go_library(
deps = [
"//pkg/log",
"//pkg/sentry/control",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/cgroup",
"//runsc/sandbox",
@@ -53,6 +54,7 @@ go_test(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 5ed131a7f..060b63bf3 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -20,7 +20,6 @@ import (
"io"
"os"
"path/filepath"
- "sync"
"syscall"
"testing"
"time"
@@ -29,6 +28,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/testutil"
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index c10f85992..b54d8f712 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -26,7 +26,6 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -39,6 +38,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/specutils"
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 4ad09ceab..2da93ec5b 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -22,7 +22,6 @@ import (
"path"
"path/filepath"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
"gvisor.dev/gvisor/runsc/testutil"
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
index d95151ea5..17a251530 100644
--- a/runsc/container/state_file.go
+++ b/runsc/container/state_file.go
@@ -20,10 +20,10 @@ import (
"io/ioutil"
"os"
"path/filepath"
- "sync"
"github.com/gofrs/flock"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
const stateFileExtension = ".state"
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index afcb41801..a9582d92b 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/syserr",
"//runsc/specutils",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index b59e1a70e..4d84ad999 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -29,7 +29,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -767,6 +767,16 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
+// TODO(b/127675828): support getxattr.
+func (l *localFile) GetXattr(name string, size uint64) (string, error) {
+ return "", syscall.EOPNOTSUPP
+}
+
+// TODO(b/127675828): support setxattr.
+func (l *localFile) SetXattr(name, value string, flags uint32) error {
+ return syscall.EOPNOTSUPP
+}
+
// Allocate implements p9.File.
func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error {
if !l.isOpen() {
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index 8001949d5..ddbc37456 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/platform",
+ "//pkg/sync",
"//pkg/tcpip/header",
"//pkg/tcpip/stack",
"//pkg/urpc",
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index be8b72b3e..ff48f5646 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -321,16 +321,21 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
}
}
- // Use SO_RCVBUFFORCE because on linux the receive buffer for an
- // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max
- // defaults to a unusually low value of 208KB. This is too low
- // for gVisor to be able to receive packets at high throughputs
- // without incurring packet drops.
- const rcvBufSize = 4 << 20 // 4MB.
-
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil {
- return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err)
+ // Use SO_RCVBUFFORCE/SO_SNDBUFFORCE because on linux the receive/send buffer
+ // for an AF_PACKET socket is capped by "net.core.rmem_max/wmem_max".
+ // wmem_max/rmem_max default to an unusually low value of 208KB. This is too low
+ // for gVisor to be able to receive packets at high throughputs without
+ // incurring packet drops.
+ const bufSize = 4 << 20 // 4MB.
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err)
}
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err)
+ }
+
return &socketEntry{deviceFile, gsoMaxSize}, nil
}
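SO_RCVBUFFORCE and SO_SNDBUFFORCE let a privileged process (CAP_NET_ADMIN) push a socket's buffers past the net.core.rmem_max/wmem_max caps on Linux, which is what the change above relies on. A standalone sketch of the same pair of calls, with a hypothetical helper name and caller-supplied size:

package example

import (
	"fmt"
	"syscall"
)

// forceSocketBuffers raises both buffers on fd past rmem_max/wmem_max.
// Requires CAP_NET_ADMIN; the caller owns fd. Illustrative helper only.
func forceSocketBuffers(fd, size int) error {
	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, size); err != nil {
		return fmt.Errorf("setting SO_RCVBUFFORCE to %d: %v", size, err)
	}
	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, size); err != nil {
		return fmt.Errorf("setting SO_SNDBUFFORCE to %d: %v", size, err)
	}
	return nil
}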
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index ce1452b87..ec72bdbfd 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -22,7 +22,6 @@ import (
"os"
"os/exec"
"strconv"
- "sync"
"syscall"
"time"
@@ -34,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD
index c96ca2eb6..3c3027cb5 100644
--- a/runsc/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -10,6 +10,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 9632776d2..fb22eae39 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -34,7 +34,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
"time"
@@ -42,6 +41,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/scripts/common.sh b/scripts/common.sh
index 6dabad141..fdb1aa142 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -73,7 +73,7 @@ function install_runsc() {
sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
# Clear old logs files that may exist.
- sudo rm -f "${RUNSC_LOGS_DIR}"/*
+ sudo rm -f "${RUNSC_LOGS_DIR}"/'*'
# Restart docker to pick up the new runtime configuration.
sudo systemctl restart docker
diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh
new file mode 100755
index 000000000..bac9b9192
--- /dev/null
+++ b/scripts/issue_reviver.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DIR=$(dirname $0)
+source "${DIR}"/common.sh
+
+# Provide a credential file if available.
+export OAUTH_TOKEN_FILE=""
+if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}"
+fi
+
+REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd)
+run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}"
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index 1c04601df..4b8bbb093 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -125,7 +125,7 @@ func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
-// FilterInputDropTCP tests that connections are not accepted on specified source ports.
+// FilterInputDropTCPDestPort tests that connections are not accepted on the specified destination port.
type FilterInputDropTCPDestPort struct{}
// Name implements TestCase.Name.
@@ -135,14 +135,13 @@ func (FilterInputDropTCPDestPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport",
- fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on drop port.
if err := listenTCP(dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("Connections on port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -151,7 +150,7 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error {
if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("Connection destined to port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -167,14 +166,13 @@ func (FilterInputDropTCPSrcPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport",
- fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on accept port.
if err := listenTCP(acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connections destined to port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -183,7 +181,7 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error {
if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection sent from port %d should not be accepted", dropPort)
+ return fmt.Errorf("connection on port %d should not be acceptedi, but got accepted", dropPort)
}
return nil
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
index 63d74e4f4..ee2c49f9a 100644
--- a/test/iptables/filter_output.go
+++ b/test/iptables/filter_output.go
@@ -34,14 +34,13 @@ func (FilterOutputDropTCPDestPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport",
- fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on accept port.
if err := listenTCP(acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connections destined to port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -50,7 +49,7 @@ func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error {
if err := connectTCP(ip, acceptPort, dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection sent from port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -66,14 +65,13 @@ func (FilterOutputDropTCPSrcPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport",
- fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
return err
}
// Listen for TCP packets on drop port.
if err := listenTCP(dropPort, sendloopDuration); err == nil {
- return fmt.Errorf("connections on port %d should not be accepted, but got accepted", dropPort)
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
}
return nil
@@ -82,8 +80,8 @@ func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error {
if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err == nil {
- return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
- }
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
return nil
}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 3eeb75b8b..d268ea9b4 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -28,7 +28,7 @@ import (
"gvisor.dev/gvisor/runsc/testutil"
)
-const timeout time.Duration = 18 * time.Second
+const timeout = 18 * time.Second
var image = flag.String("image", "bazel/test/iptables/runner:runner", "image to run tests in")
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 44945bd89..1c4f4f665 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -81,33 +81,33 @@ func sendUDPLoop(ip net.IP, port int, duration time.Duration) error {
return nil
}
-// listenTCP listens for connections on a TCP port
+// listenTCP listens for connections on a TCP port.
func listenTCP(port int, timeout time.Duration) error {
localAddr := net.TCPAddr{
Port: port,
}
- // Starts listening on port
+ // Starts listening on port.
lConn, err := net.ListenTCP("tcp4", &localAddr)
if err != nil {
return err
}
defer lConn.Close()
- // Accept connections on port
+ // Accept connections on port.
lConn.SetDeadline(time.Now().Add(timeout))
conn, err := lConn.AcceptTCP()
- if err == nil {
- conn.Close()
+ if err != nil {
+ return err
}
- return err
+ conn.Close()
+ return nil
}
-// connectTCP connects the TCP server over specified local port, server IP
-// and remote/server port
-func connectTCP(ip net.IP, remotePort int, localPort int, duration time.Duration) error {
+// connectTCP connects the TCP server over specified local port, server IP and remote/server port.
+func connectTCP(ip net.IP, remotePort, localPort int, duration time.Duration) error {
remote := net.TCPAddr{
- IP: ip,
+ IP: ip,
Port: remotePort,
}
@@ -115,23 +115,21 @@ func connectTCP(ip net.IP, remotePort int, localPort int, duration time.Duration
Port: localPort,
}
- // Container may not be up. Retry DialTCP
- // over a given duration
+ // Container may not be up. Retry DialTCP over a duration.
to := time.After(duration)
- var res error
- for timedOut := false; !timedOut; {
+ for {
conn, err := net.DialTCP("tcp4", &local, &remote)
- res = err
- if res == nil {
+ if err == nil {
conn.Close()
return nil
}
- select{
+ select {
+ // Timed out waiting for connection to be accepted.
case <-to:
- timedOut = true
+ return err
default:
time.Sleep(200 * time.Millisecond)
}
}
- return res
+ return fmt.Errorf("Failed to establish connection on port %d", localPort)
}
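The rewritten connectTCP above follows a common retry-until-deadline shape: attempt the operation, and on failure either give up when the deadline channel fires or sleep briefly and try again. A generic, illustrative version of that pattern is sketched below; the names are hypothetical and it is not part of the test utilities.

package example

import "time"

// retryUntil keeps calling op until it succeeds or the duration elapses,
// returning the last error seen on timeout. Purely illustrative.
func retryUntil(d time.Duration, op func() error) error {
	deadline := time.After(d)
	for {
		err := op()
		if err == nil {
			return nil
		}
		select {
		case <-deadline:
			// Deadline reached; report the most recent failure.
			return err
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}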
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 72c413af2..b5c6f927e 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -20,7 +20,7 @@ import (
)
const (
- redirectPort = 42
+ redirectPort = 42
)
func init() {
@@ -28,7 +28,7 @@ func init() {
RegisterTestCase(NATDropUDP{})
}
-// InputRedirectUDPPort tests that packets are redirected to different port.
+// NATRedirectUDPPort tests that packets are redirected to different port.
type NATRedirectUDPPort struct{}
// Name implements TestCase.Name.
@@ -38,8 +38,7 @@ func (NATRedirectUDPPort) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (NATRedirectUDPPort) ContainerAction(ip net.IP) error {
- if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports",
- fmt.Sprintf("%d", redirectPort)); err != nil {
+ if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
return err
}
@@ -64,8 +63,7 @@ func (NATDropUDP) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (NATDropUDP) ContainerAction(ip net.IP) error {
- if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports",
- fmt.Sprintf("%d", redirectPort)); err != nil {
+ if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
return err
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index a3a85917d..829693e8e 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -717,11 +717,6 @@ syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
-syscall_test(
- add_overlay = True,
- test = "//test/syscalls/linux:xattr_test",
-)
-
go_binary(
name = "syscall_test_runner",
testonly = 1,
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ce8abe217..4c7ec3f06 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2693,6 +2693,7 @@ cc_binary(
srcs = ["socket_inet_loopback.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:posix_error",
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index 59ec9940a..fdef646eb 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -977,7 +977,7 @@ TEST(Inotify, WatchOnRelativePath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
// Change working directory to root.
- const char* old_working_dir = get_current_dir_name();
+ const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH));
EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds());
// Add a watch on file1 with a relative path.
@@ -997,7 +997,7 @@ TEST(Inotify, WatchOnRelativePath) {
// continue to hold a reference, random save/restore tests can fail if a save
// is triggered after "root" is unlinked; we can't save deleted fs objects
// with active references.
- EXPECT_THAT(chdir(old_working_dir), SyscallSucceeds());
+ EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds());
}
TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) {
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 9e5aa7fd0..c42472474 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -275,7 +275,8 @@ TEST_F(PollTest, Nfds) {
// Each entry in the 'fds' array refers to the eventfd and polls for
// "writable" events (events=POLLOUT). This essentially guarantees that the
// poll() is a no-op and allows negative testing of the 'nfds' parameter.
- std::vector<struct pollfd> fds(max_fds, {.fd = efd.get(), .events = POLLOUT});
+ std::vector<struct pollfd> fds(max_fds + 1,
+ {.fd = efd.get(), .events = POLLOUT});
// Verify that 'nfds' up to RLIMIT_NOFILE are allowed.
EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1));
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
index c9246367d..cd936ea90 100644
--- a/test/syscalls/linux/preadv2.cc
+++ b/test/syscalls/linux/preadv2.cc
@@ -202,7 +202,7 @@ TEST(Preadv2Test, TestInvalidOffset) {
iov[0].iov_len = 0;
EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8,
- /*flags=*/RWF_HIPRI),
+ /*flags=*/0),
SyscallFailsWithErrno(EINVAL));
}
diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc
index 491d5f40f..2694dc64f 100644
--- a/test/syscalls/linux/readv_common.cc
+++ b/test/syscalls/linux/readv_common.cc
@@ -154,7 +154,7 @@ void ReadBuffersOverlapping(int fd) {
char* expected_ptr = expected.data();
memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes);
memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes],
- kReadvTestDataSize);
+ kReadvTestDataSize - overlap_bytes);
struct iovec iovs[2];
iovs[0].iov_base = buffer.data();
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 5767181a1..5ed57625c 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -183,7 +183,14 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not time out under gotsan runs, where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -198,15 +205,29 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
for (int i = 0; i < kConnectAttempts; i++) {
- FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack, where read incorrectly assumes that a whole
+ // segment has been consumed when endpoint.Read() is called, even though
+ // the syscall that invoked endpoint.Read() may only consume it
+ // partially. As a result, a close() of such a socket does not trigger a
+ // RST in netstack, because the endpoint believes it has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
// Join threads to be sure that all connections have been counted.
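For reference, a minimal standalone sketch of the short-read-then-close pattern the comments in this hunk describe (plain POSIX; the function name and the assumption that the peer sent more than two bytes are illustrative, not part of the change):

#include <stdint.h>
#include <unistd.h>

// Read fewer bytes than the peer sent, then close. Closing a socket that
// still has unread data makes the stack send a RST rather than a FIN
// (RFC 2525, section 2.17), so neither end lingers in TIME-WAIT.
void ShortReadThenClose(int connected_fd) {
  uint16_t data;
  read(connected_fd, &data, sizeof(data));  // leaves queued data unread
  close(connected_fd);                      // elicits a RST, not a FIN
}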
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 619d41901..2f9821555 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -32,6 +32,7 @@
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -102,6 +103,161 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) {
SyscallFailsWithErrno(EAFNOSUPPORT));
}
+enum class Operation {
+ Bind,
+ Connect,
+ SendTo,
+};
+
+std::string OperationToString(Operation operation) {
+ switch (operation) {
+ case Operation::Bind:
+ return "Bind";
+ case Operation::Connect:
+ return "Connect";
+ case Operation::SendTo:
+ return "SendTo";
+ }
+ // All enum values are handled above; this return only avoids falling
+ // off the end of a non-void function if a bogus value is passed.
+ return "Unknown";
+}
+
+using OperationSequence = std::vector<Operation>;
+
+using DualStackSocketTest =
+ ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>;
+
+TEST_P(DualStackSocketTest, AddressOperations) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0));
+
+ const TestAddress& addr = std::get<0>(GetParam());
+ const OperationSequence& operations = std::get<1>(GetParam());
+
+ auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr);
+
+ // Sockets may only be bound once. Both `connect` and `sendto` cause a
+ // socket to be bound.
+ bool bound = false;
+ for (const Operation& operation : operations) {
+ bool sockname = false;
+ bool peername = false;
+ switch (operation) {
+ case Operation::Bind: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0));
+
+ int bind_ret = bind(fd.get(), addr_in, addr.addr_len);
+
+ // Dual stack sockets may only be bound to AF_INET6.
+ if (!bound && addr.family() == AF_INET6) {
+ EXPECT_THAT(bind_ret, SyscallSucceeds());
+ bound = true;
+
+ sockname = true;
+ } else {
+ EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL));
+ }
+ break;
+ }
+ case Operation::Connect: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ EXPECT_THAT(connect(fd.get(), addr_in, addr.addr_len),
+ SyscallSucceeds())
+ << GetAddrStr(addr_in);
+ bound = true;
+
+ sockname = true;
+ peername = true;
+
+ break;
+ }
+ case Operation::SendTo: {
+ const char payload[] = "hello";
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0,
+ addr_in, addr.addr_len);
+
+ EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload)));
+ sockname = !bound;
+ bound = true;
+ break;
+ }
+ }
+
+ if (sockname) {
+ sockaddr_storage sock_addr;
+ socklen_t addrlen = sizeof(sock_addr);
+ ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr);
+
+ if (operation == Operation::SendTo) {
+ EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6);
+ EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsockname="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+
+ EXPECT_NE(sock_addr_in6->sin6_port, 0);
+ } else if (IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsockname="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ }
+ }
+
+ if (peername) {
+ sockaddr_storage peer_addr;
+ socklen_t addrlen = sizeof(peer_addr);
+ ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ if (addr.family() == AF_INET ||
+ IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(&peer_addr)
+ ->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getpeername="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr));
+ }
+ }
+ }
+}
+
+// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny.
+INSTANTIATE_TEST_SUITE_P(
+ All, DualStackSocketTest,
+ ::testing::Combine(
+ ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/
+ V4MappedLoopback(), V6Any(), V6Loopback()),
+ ::testing::ValuesIn<OperationSequence>(
+ {{Operation::Bind, Operation::Connect, Operation::SendTo},
+ {Operation::Bind, Operation::SendTo, Operation::Connect},
+ {Operation::Connect, Operation::Bind, Operation::SendTo},
+ {Operation::Connect, Operation::SendTo, Operation::Bind},
+ {Operation::SendTo, Operation::Bind, Operation::Connect},
+ {Operation::SendTo, Operation::Connect, Operation::Bind}})),
+ [](::testing::TestParamInfo<
+ std::tuple<TestAddress, OperationSequence>> const& info) {
+ const TestAddress& addr = std::get<0>(info.param);
+ const OperationSequence& operations = std::get<1>(info.param);
+ std::string s = addr.description;
+ for (const Operation& operation : operations) {
+ absl::StrAppend(&s, OperationToString(operation));
+ }
+ return s;
+ });
+
void tcpSimpleConnectTest(TestAddress const& listener,
TestAddress const& connector, bool unbound) {
// Create the listening socket.
@@ -377,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
// Sleep for a little over the linger timeout to reduce flakiness in
// save/restore tests.
- absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2));
ds.reset();
@@ -714,7 +870,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
sockaddr_storage listen_addr = listener.addr;
sockaddr_storage conn_addr = connector.addr;
constexpr int kThreadCount = 3;
- constexpr int kConnectAttempts = 4096;
+ constexpr int kConnectAttempts = 10000;
// Create the listening socket.
FileDescriptor listener_fds[kThreadCount];
@@ -729,7 +885,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
ASSERT_THAT(
bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(fd, kConnectAttempts / 3), SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
// On the first bind we need to determine which port was bound.
if (i != 0) {
@@ -772,7 +928,14 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not time out under gotsan runs, where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -795,8 +958,22 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack, where read incorrectly assumes that a whole
+ // segment has been consumed when endpoint.Read() is called, even though
+ // the syscall that invoked endpoint.Read() may only consume it
+ // partially. As a result, a close() of such a socket does not trigger a
+ // RST in netstack, because the endpoint believes it has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
});
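For reference, a minimal sketch of the implicit-bind behaviour the new DualStackSocketTest relies on: connect() (or sendto()) on an unbound AF_INET6 UDP socket binds it, so getsockname() afterwards reports a sockaddr_in6 with a non-zero port (plain POSIX; the port number is illustrative and error handling is omitted):

#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

int main() {
  int fd = socket(AF_INET6, SOCK_DGRAM, 0);

  sockaddr_in6 dst = {};
  dst.sin6_family = AF_INET6;
  dst.sin6_addr = in6addr_loopback;
  dst.sin6_port = htons(1337);
  connect(fd, reinterpret_cast<sockaddr*>(&dst), sizeof(dst));  // implicit bind

  sockaddr_in6 local = {};
  socklen_t len = sizeof(local);
  getsockname(fd, reinterpret_cast<sockaddr*>(&local), &len);
  // local.sin6_port is now non-zero: connect() bound the socket.
  close(fd);
  return 0;
}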
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 66eb68857..53290bed7 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
EXPECT_EQ(get, kSockOptOn);
}
+// Ensure that IP_RECVTOS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS works as expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
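For reference, a minimal sketch of what IP_RECVTOS enables on the receive path: once the option is on, each incoming datagram carries an IPPROTO_IP/IP_TOS control message holding the one-byte TOS value (plain POSIX; fd is assumed to be a bound IPv4 UDP socket, and error handling is omitted):

#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

unsigned char TosOfNextDatagram(int fd) {
  int on = 1;
  setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &on, sizeof(on));

  char payload[64];
  char control[CMSG_SPACE(sizeof(unsigned char))];
  struct iovec iov = {payload, sizeof(payload)};
  struct msghdr msg = {};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = control;
  msg.msg_controllen = sizeof(control);
  recvmsg(fd, &msg, 0);

  // Walk the ancillary data looking for the TOS byte.
  unsigned char tos = 0;
  for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr;
       cmsg = CMSG_NXTHDR(&msg, cmsg)) {
    if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS) {
      memcpy(&tos, CMSG_DATA(cmsg), sizeof(tos));
    }
  }
  return tos;
}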
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 6b99c021d..33a5ac66c 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) {
t_ = -1;
}
+TEST_P(TcpSocketTest, PollAfterShutdown) {
+ ScopedThread client_thread([this]() {
+ EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+ });
+
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+}
+
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
// Initialize address to the loopback one.
sockaddr_storage addr =
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index dc35c2f50..68e0a8109 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
// outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SetAndReceiveTOS) {
- // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack.
- SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
+ // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack.
+ SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() &&
+ !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
@@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) {
// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SendAndReceiveTOS) {
- // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack.
+ // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack.
+ // TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
index 75740238c..b3bc3463e 100644
--- a/test/syscalls/linux/xattr.cc
+++ b/test/syscalls/linux/xattr.cc
@@ -59,7 +59,8 @@ TEST_F(XattrTest, XattrLargeName) {
std::string name = "user.";
name += std::string(XATTR_NAME_MAX - name.length(), 'a');
- // TODO(b/127675828): Support setxattr and getxattr.
+ // An xattr should be whitelisted before it can be accessed; do not allow
+ // arbitrary xattrs to be read/written in gVisor.
if (!IsRunningOnGvisor()) {
EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
SyscallSucceeds());
@@ -83,59 +84,53 @@ TEST_F(XattrTest, XattrInvalidPrefix) {
SyscallFailsWithErrno(EOPNOTSUPP));
}
-TEST_F(XattrTest, XattrReadOnly) {
+// Do not allow save/restore cycles after making the test file read-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
- SyscallSucceeds());
- }
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+ DisableSave ds;
ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR));
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
SyscallFailsWithErrno(EACCES));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- char buf = '-';
- EXPECT_THAT(getxattr(path, name, &buf, size),
- SyscallSucceedsWithValue(size));
- EXPECT_EQ(buf, val);
- }
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
}
-TEST_F(XattrTest, XattrWriteOnly) {
+// Do not allow save/restore cycles after making the test file write-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
// Drop capabilities that allow us to override file and directory permissions.
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ DisableSave ds;
ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR));
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
- SyscallSucceeds());
- }
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES));
}
TEST_F(XattrTest, XattrTrustedWithNonadmin) {
- // TODO(b/127675828): Support setxattr and getxattr.
+ // TODO(b/127675828): Support setxattr and getxattr with "trusted" prefix.
SKIP_IF(IsRunningOnGvisor());
SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
@@ -147,11 +142,8 @@ TEST_F(XattrTest, XattrTrustedWithNonadmin) {
}
TEST_F(XattrTest, XattrOnDirectory) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(dir.path().c_str(), name, NULL, 0, /*flags=*/0),
SyscallSucceeds());
EXPECT_THAT(getxattr(dir.path().c_str(), name, NULL, 0),
@@ -159,13 +151,10 @@ TEST_F(XattrTest, XattrOnDirectory) {
}
TEST_F(XattrTest, XattrOnSymlink) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(link.path().c_str(), name, NULL, 0, /*flags=*/0),
SyscallSucceeds());
EXPECT_THAT(getxattr(link.path().c_str(), name, NULL, 0),
@@ -173,7 +162,7 @@ TEST_F(XattrTest, XattrOnSymlink) {
}
TEST_F(XattrTest, XattrOnInvalidFileTypes) {
- char name[] = "user.abc";
+ const char name[] = "user.test";
char char_device[] = "/dev/zero";
EXPECT_THAT(setxattr(char_device, name, NULL, 0, /*flags=*/0),
@@ -191,11 +180,8 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) {
}
TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
size_t size = 1;
EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
@@ -209,11 +195,8 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
}
TEST_F(XattrTest, SetxattrZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
@@ -225,7 +208,7 @@ TEST_F(XattrTest, SetxattrZeroSize) {
TEST_F(XattrTest, SetxattrSizeTooLarge) {
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
// Note that each particular fs implementation may stipulate a lower size
// limit, in which case we actually may fail (e.g. error with ENOSPC) for
@@ -235,43 +218,29 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) {
EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
SyscallFailsWithErrno(E2BIG));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(getxattr(path, name, nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
- }
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0),
SyscallFailsWithErrno(EFAULT));
- // TODO(b/127675828): Support setxattr and getxattr.
- if (!IsRunningOnGvisor()) {
- EXPECT_THAT(getxattr(path, name, nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
- }
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
}
TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val(XATTR_SIZE_MAX + 1);
std::fill(val.begin(), val.end(), 'a');
size_t size = 1;
@@ -286,11 +255,8 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
}
TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
SyscallSucceeds());
@@ -304,11 +270,8 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
}
TEST_F(XattrTest, SetxattrReplaceWithLarger) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
SyscallSucceeds());
@@ -321,11 +284,8 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) {
}
TEST_F(XattrTest, SetxattrCreateFlag) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
SyscallSucceeds());
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
@@ -335,11 +295,8 @@ TEST_F(XattrTest, SetxattrCreateFlag) {
}
TEST_F(XattrTest, SetxattrReplaceFlag) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
SyscallFailsWithErrno(ENODATA));
EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
@@ -357,11 +314,8 @@ TEST_F(XattrTest, SetxattrInvalidFlags) {
}
TEST_F(XattrTest, Getxattr) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
int val = 1234;
size_t size = sizeof(val);
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -372,11 +326,8 @@ TEST_F(XattrTest, Getxattr) {
}
TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
std::vector<char> val = {'a', 'a'};
size_t size = val.size();
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -387,11 +338,8 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
}
TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds());
@@ -405,11 +353,8 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
}
TEST_F(XattrTest, GetxattrZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
SyscallSucceeds());
@@ -421,11 +366,8 @@ TEST_F(XattrTest, GetxattrZeroSize) {
}
TEST_F(XattrTest, GetxattrSizeTooLarge) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
SyscallSucceeds());
@@ -440,11 +382,8 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) {
}
TEST_F(XattrTest, GetxattrNullValue) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
@@ -454,11 +393,8 @@ TEST_F(XattrTest, GetxattrNullValue) {
}
TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
- // TODO(b/127675828): Support setxattr and getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- char name[] = "user.abc";
+ const char name[] = "user.test";
char val = 'a';
size_t size = sizeof(val);
// Set value with zero size.
@@ -473,13 +409,9 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
}
TEST_F(XattrTest, GetxattrNonexistentName) {
- // TODO(b/127675828): Support getxattr.
- SKIP_IF(IsRunningOnGvisor());
-
const char* path = test_file_name_.c_str();
- std::string name = "user.nonexistent";
- EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
- SyscallFailsWithErrno(ENODATA));
+ const char name[] = "user.test";
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
}
} // namespace
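For reference, a minimal standalone sketch of the user.* round trip that the rewritten xattr tests now exercise on gVisor as well as Linux (the path and attribute names are illustrative, not taken from the tests):

#include <errno.h>
#include <stdio.h>
#include <sys/xattr.h>

int main() {
  const char* path = "/tmp/xattr_demo";  // assumed to exist on an xattr-capable fs
  const char val = 'a';

  // Set a small user.* attribute, then read it back.
  setxattr(path, "user.demo", &val, sizeof(val), /*flags=*/0);

  char got = '-';
  if (getxattr(path, "user.demo", &got, sizeof(got)) == sizeof(got)) {
    printf("user.demo=%c\n", got);
  }

  // Reading a name that was never set fails with ENODATA.
  if (getxattr(path, "user.missing", nullptr, 0) < 0 && errno == ENODATA) {
    printf("user.missing: ENODATA, as expected\n");
  }
  return 0;
}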